# Copyright (c) 2022-2025, The Isaac Lab Project Developers.# All rights reserved.## SPDX-License-Identifier: BSD-3-Clause"""Command manager for generating and updating commands."""from__future__importannotationsimportinspectimporttorchimportweakreffromabcimportabstractmethodfromcollections.abcimportSequencefromprettytableimportPrettyTablefromtypingimportTYPE_CHECKINGimportomni.kit.appfrom.manager_baseimportManagerBase,ManagerTermBasefrom.manager_term_cfgimportCommandTermCfgifTYPE_CHECKING:fromisaaclab.envsimportManagerBasedRLEnv
class CommandTerm(ManagerTermBase):
    """The base class for implementing a command term.

    A command term is used to generate commands for goal-conditioned tasks. For example,
    in the case of a goal-conditioned navigation task, the command term can be used to
    generate a target position for the robot to navigate to.

    It implements a resampling mechanism that allows the command to be resampled at a fixed
    frequency. The resampling frequency can be specified in the configuration object.
    Additionally, it is possible to assign a visualization function to the command term
    that can be used to visualize the command in the simulator.
    """

    def __init__(self, cfg: CommandTermCfg, env: ManagerBasedRLEnv):
        """Initialize the command generator class.

        Args:
            cfg: The configuration parameters for the command generator.
            env: The environment object.
        """
        super().__init__(cfg, env)

        # create buffers to store the command
        # -- metrics that can be used for logging
        self.metrics = dict()
        # -- time left before resampling
        self.time_left = torch.zeros(self.num_envs, device=self.device)
        # -- counter for the number of times the command has been resampled within the current episode
        self.command_counter = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)

        # add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
        self._debug_vis_handle = None
        # set initial state of debug visualization
        self.set_debug_vis(self.cfg.debug_vis)

    def __del__(self):
        """Unsubscribe from the callbacks."""
        # ``__del__`` also runs for partially constructed objects (e.g. when the base-class
        # constructor raised), in which case the handle attribute does not exist yet.
        handle = getattr(self, "_debug_vis_handle", None)
        if handle:
            handle.unsubscribe()
            self._debug_vis_handle = None

    """
    Properties
    """

    @property
    @abstractmethod
    def command(self) -> torch.Tensor:
        """The command tensor. Shape is (num_envs, command_dim)."""
        raise NotImplementedError

    @property
    def has_debug_vis_implementation(self) -> bool:
        """Whether the command generator has a debug visualization implemented."""
        # check if function raises NotImplementedError
        try:
            source_code = inspect.getsource(self._set_debug_vis_impl)
        except (OSError, TypeError):
            # the source is not retrievable (e.g. dynamically created classes):
            # fall back to checking whether the implementation hook was overridden
            return type(self)._set_debug_vis_impl is not CommandTerm._set_debug_vis_impl
        return "NotImplementedError" not in source_code

    """
    Operations.
    """

    def set_debug_vis(self, debug_vis: bool) -> bool:
        """Sets whether to visualize the command data.

        Args:
            debug_vis: Whether to visualize the command data.

        Returns:
            Whether the debug visualization was successfully set. False if the command
            generator does not support debug visualization.
        """
        # check if debug visualization is supported
        if not self.has_debug_vis_implementation:
            return False
        # toggle debug visualization objects
        self._set_debug_vis_impl(debug_vis)
        # toggle debug visualization handles
        if debug_vis:
            # create a subscriber for the post update event if it doesn't exist
            if self._debug_vis_handle is None:
                app_interface = omni.kit.app.get_app_interface()
                # a weak proxy is captured so that the subscription does not keep the term alive
                self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
                    lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event)
                )
        else:
            # remove the subscriber if it exists
            if self._debug_vis_handle is not None:
                self._debug_vis_handle.unsubscribe()
                self._debug_vis_handle = None
        # return success
        return True

    def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
        """Reset the command generator and log metrics.

        This function resets the command counter and resamples the command. It should be called
        at the beginning of each episode.

        Args:
            env_ids: The list of environment IDs to reset. Defaults to None, in which case
                all the environments are reset.

        Returns:
            A dictionary containing the information to log under the "{name}" key.
        """
        # resolve the environment IDs
        if env_ids is None:
            env_ids = slice(None)

        # add logging metrics
        extras = {}
        for metric_name, metric_value in self.metrics.items():
            # compute the mean metric value
            extras[metric_name] = torch.mean(metric_value[env_ids]).item()
            # reset the metric value
            metric_value[env_ids] = 0.0

        # set the command counter to zero
        self.command_counter[env_ids] = 0
        # resample the command (slices are resolved into explicit indices inside)
        self._resample(env_ids)

        return extras

    def compute(self, dt: float):
        """Compute the command.

        Args:
            dt: The time step passed since the last call to compute.
        """
        # update the metrics based on current state
        self._update_metrics()
        # reduce the time left before resampling
        self.time_left -= dt
        # resample the command if necessary
        resample_env_ids = (self.time_left <= 0.0).nonzero().flatten()
        if len(resample_env_ids) > 0:
            self._resample(resample_env_ids)
        # update the command
        self._update_command()

    """
    Helper functions.
    """

    def _resample(self, env_ids: Sequence[int]):
        """Resample the command.

        This function resamples the command and time for which the command is applied for the
        specified environment indices.

        Args:
            env_ids: The list of environment IDs to resample. A ``slice`` (as produced by
                :meth:`reset` when no IDs are given) is also accepted and is expanded into
                explicit environment indices.
        """
        # a slice has no length, so expand it into explicit indices before the emptiness check;
        # this also gives ``_resample_command`` an indexable collection of IDs
        if isinstance(env_ids, slice):
            env_ids = torch.arange(self.num_envs, device=self.device)[env_ids]

        if len(env_ids) != 0:
            # resample the time left before resampling
            self.time_left[env_ids] = self.time_left[env_ids].uniform_(*self.cfg.resampling_time_range)
            # increment the command counter
            self.command_counter[env_ids] += 1
            # resample the command
            self._resample_command(env_ids)

    """
    Implementation specific functions.
    """

    @abstractmethod
    def _update_metrics(self):
        """Update the metrics based on the current state."""
        raise NotImplementedError

    @abstractmethod
    def _resample_command(self, env_ids: Sequence[int]):
        """Resample the command for the specified environments."""
        raise NotImplementedError

    @abstractmethod
    def _update_command(self):
        """Update the command based on the current state."""
        raise NotImplementedError

    def _set_debug_vis_impl(self, debug_vis: bool):
        """Set debug visualization into visualization objects.

        This function is responsible for creating the visualization objects if they don't exist
        and input ``debug_vis`` is True. If the visualization objects exist, the function should
        set their visibility into the stage.
        """
        raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")

    def _debug_vis_callback(self, event):
        """Callback for debug visualization.

        This function calls the visualization objects and sets the data to visualize into them.
        """
        raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
class CommandManager(ManagerBase):
    """Manager for generating commands.

    The command manager is used to generate commands for an agent to execute. It makes it convenient to switch
    between different command generation strategies within the same environment. For instance, in an environment
    consisting of a quadrupedal robot, the command to it could be a velocity command or position command.
    By keeping the command generation logic separate from the environment, it is easy to switch between different
    command generation strategies.

    The command terms are implemented as classes that inherit from the :class:`CommandTerm` class.
    Each command generator term should also have a corresponding configuration class that inherits from the
    :class:`CommandTermCfg` class.
    """

    _env: ManagerBasedRLEnv
    """The environment instance."""

    def __init__(self, cfg: object, env: ManagerBasedRLEnv):
        """Initialize the command manager.

        Args:
            cfg: The configuration object or dictionary (``dict[str, CommandTermCfg]``).
            env: The environment instance.
        """
        # create buffers to parse and store terms
        self._terms: dict[str, CommandTerm] = dict()

        # call the base class constructor (this prepares the terms)
        super().__init__(cfg, env)
        # store the commands
        self._commands = dict()
        # aggregate the debug visualization flag of all the terms into the manager configuration
        # note: a dictionary configuration cannot carry attributes, so the flag is only stored
        #   on configuration objects
        if self.cfg and not isinstance(self.cfg, dict):
            self.cfg.debug_vis = any(term.cfg.debug_vis for term in self._terms.values())

    def __str__(self) -> str:
        """Returns: A string representation for the command manager."""
        msg = f"<CommandManager> contains {len(self._terms)} active terms.\n"

        # create table for term information
        table = PrettyTable()
        table.title = "Active Command Terms"
        table.field_names = ["Index", "Name", "Type"]
        # set alignment of table columns
        table.align["Name"] = "l"
        # add info on each term
        for index, (name, term) in enumerate(self._terms.items()):
            table.add_row([index, name, term.__class__.__name__])
        # convert table to string
        msg += table.get_string()
        msg += "\n"

        return msg

    """
    Properties.
    """

    @property
    def active_terms(self) -> list[str]:
        """Name of active command terms."""
        return list(self._terms.keys())

    @property
    def has_debug_vis_implementation(self) -> bool:
        """Whether the command terms have debug visualization implemented."""
        # true as soon as one of the terms implements debug visualization
        return any(term.has_debug_vis_implementation for term in self._terms.values())

    """
    Operations.
    """

    def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
        """Returns the active terms as iterable sequence of tuples.

        The first element of the tuple is the name of the term and the second element is the raw value(s) of the term.

        Args:
            env_idx: The specific environment to pull the active terms from.

        Returns:
            The active terms.
        """
        return [(name, term.command[env_idx].cpu().tolist()) for name, term in self._terms.items()]

    def set_debug_vis(self, debug_vis: bool) -> bool:
        """Sets whether to visualize the command data.

        Args:
            debug_vis: Whether to visualize the command data.

        Returns:
            Whether the debug visualization was successfully set for at least one term.
            False if none of the command terms support debug visualization.
        """
        # forward the request to every term (no short-circuiting: all terms must be toggled)
        results = [term.set_debug_vis(debug_vis) for term in self._terms.values()]
        return any(results)

    def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
        """Reset the command terms and log their metrics.

        This function resets the command counter and resamples the command for each term. It should be called
        at the beginning of each episode.

        Args:
            env_ids: The list of environment IDs to reset. Defaults to None.

        Returns:
            A dictionary containing the information to log under the "Metrics/{term_name}/{metric_name}" key.
        """
        # resolve environment ids
        if env_ids is None:
            env_ids = slice(None)
        # store information
        extras = {}
        for name, term in self._terms.items():
            # reset the command term
            metrics = term.reset(env_ids=env_ids)
            # store the (already averaged) metric values under a per-term key
            for metric_name, metric_value in metrics.items():
                extras[f"Metrics/{name}/{metric_name}"] = metric_value
        # return logged information
        return extras

    def compute(self, dt: float):
        """Updates the commands.

        This function calls each command term managed by the class.

        Args:
            dt: The time-step interval of the environment.
        """
        # iterate over all the command terms
        for term in self._terms.values():
            # compute term's value
            term.compute(dt)

    def get_command(self, name: str) -> torch.Tensor:
        """Returns the command for the specified command term.

        Args:
            name: The name of the command term.

        Returns:
            The command tensor of the specified command term.
        """
        return self._terms[name].command

    def get_term(self, name: str) -> CommandTerm:
        """Returns the command term with the specified name.

        Args:
            name: The name of the command term.

        Returns:
            The command term with the specified name.
        """
        return self._terms[name]

    """
    Helper functions.
    """

    def _prepare_terms(self):
        """Create the command terms from the configuration and store them by name.

        Raises:
            TypeError: If a term configuration is not a :class:`CommandTermCfg`, or if the
                object created from it is not a :class:`CommandTerm`.
        """
        # check if config is dict already
        if isinstance(self.cfg, dict):
            cfg_items = self.cfg.items()
        else:
            cfg_items = self.cfg.__dict__.items()
        # iterate over all the terms
        for term_name, term_cfg in cfg_items:
            # check for non config
            if term_cfg is None:
                continue
            # check for valid config type
            if not isinstance(term_cfg, CommandTermCfg):
                raise TypeError(
                    f"Configuration for the term '{term_name}' is not of type CommandTermCfg."
                    f" Received: '{type(term_cfg)}'."
                )
            # create the command term
            term = term_cfg.class_type(term_cfg, self._env)
            # sanity check if term is valid type
            if not isinstance(term, CommandTerm):
                raise TypeError(f"Returned object for the term '{term_name}' is not of type CommandTerm.")
            # add class to dict
            self._terms[term_name] = term