# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Wrapper to configure an environment instance to Stable-Baselines3 vectorized environment.

The following example shows how to wrap an environment for Stable-Baselines3:

.. code-block:: python

    from isaaclab_rl.sb3 import Sb3VecEnvWrapper

    env = Sb3VecEnvWrapper(env)

"""

# needed to import for allowing type-hinting: torch.Tensor | dict[str, torch.Tensor]
from __future__ import annotations

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn  # noqa: F401
from typing import Any

from stable_baselines3.common.utils import constant_fn
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn

from isaaclab.envs import DirectRLEnv, ManagerBasedRLEnv

"""
Configuration Parser.
"""
def process_sb3_cfg(cfg: dict) -> dict:
    """Convert simple YAML types to Stable-Baselines classes/components.

    Args:
        cfg: A configuration dictionary.

    Returns:
        A dictionary containing the converted configuration.

    Reference:
        https://github.com/DLR-RM/rl-baselines3-zoo/blob/0e5eb145faefa33e7d79c7f8c179788574b20da5/utils/exp_manager.py#L358
    """

    def update_dict(hyperparams: dict[str, Any]) -> dict[str, Any]:
        for key, value in hyperparams.items():
            if isinstance(value, dict):
                update_dict(value)
            else:
                if key in ["policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"]:
                    hyperparams[key] = eval(value)
                elif key in ["learning_rate", "clip_range", "clip_range_vf", "delta_std"]:
                    if isinstance(value, str):
                        _, initial_value = value.split("_")
                        initial_value = float(initial_value)
                        # bind the initial value as a default argument to avoid the late-binding
                        # closure pitfall when several schedule keys are parsed in the same loop
                        hyperparams[key] = lambda progress_remaining, iv=initial_value: progress_remaining * iv
                    elif isinstance(value, (float, int)):
                        # negative value: ignore (ex: for clipping)
                        if value < 0:
                            continue
                        hyperparams[key] = constant_fn(float(value))
                    else:
                        raise ValueError(f"Invalid value for {key}: {hyperparams[key]}")
        return hyperparams

    # parse agent configuration and convert to classes
    return update_dict(cfg)
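
# A minimal usage sketch for ``process_sb3_cfg`` (the dictionary below is illustrative,
# not taken from any shipped task config): a ``lin_<value>`` string becomes a schedule
# that decays linearly with the remaining training progress, a plain float becomes a
# constant schedule via ``constant_fn``, and a string ``policy_kwargs`` is evaluated
# into a Python dict.
#
#     agent_cfg = process_sb3_cfg({
#         "policy": "MlpPolicy",
#         "learning_rate": "lin_0.001",
#         "clip_range": 0.2,
#         "policy_kwargs": "dict(net_arch=[64, 64])",
#     })
#     agent_cfg["learning_rate"](0.5)  # -> 0.0005 halfway through training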
"""Vectorized environment wrapper."""
class Sb3VecEnvWrapper(VecEnv):
    """Wraps around Isaac Lab environment for Stable-Baselines3.

    Isaac Sim internally implements a vectorized environment. However, since it is
    still considered a single environment instance, Stable-Baselines tries to wrap
    around it using the :class:`DummyVecEnv`. This is only done if the environment
    is not inheriting from their :class:`VecEnv`. Thus, this class thinly wraps over
    the environment from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.

    Note:
        While Stable-Baselines3 supports the Gym 0.26+ API, their vectorized environment
        still uses the old API (i.e. it is closer to Gym 0.21). Thus, we implement
        the old API for the vectorized environment.

    We also add monitoring functionality that computes the un-discounted episode
    return and length. This information is added to the info dicts under the key ``episode``.

    In contrast to the Isaac Lab environment, Stable-Baselines3 expects the following:

    1. numpy datatype for MDP signals
    2. a list of info dicts for each sub-environment (instead of a dict)
    3. when an environment has terminated, the observations from the environment should
       correspond to the ones after reset. The "real" final observation is passed using
       the info dicts under the key ``terminal_observation``.

    .. warning::

        By the nature of physics stepping in Isaac Sim, it is not possible to forward the
        simulation buffers without performing a physics step. Thus, reset is performed
        inside the :meth:`step()` function after the actual physics step is taken.
        Thus, the returned observations for terminated environments are the ones after the reset.

    .. caution::

        This class must be the last wrapper in the wrapper chain. This is because the wrapper
        does not follow the :class:`gym.Wrapper` interface. Any subsequent wrappers will need
        to be modified to work with this wrapper.

    Reference:

    1. https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html
    2. https://stable-baselines3.readthedocs.io/en/master/common/monitor.html

    """
    def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv):
        """Initialize the wrapper.

        Args:
            env: The environment to wrap around.

        Raises:
            ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv`
                or :class:`DirectRLEnv`.
        """
        # check that input is valid
        if not isinstance(env.unwrapped, ManagerBasedRLEnv) and not isinstance(env.unwrapped, DirectRLEnv):
            raise ValueError(
                "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:"
                f" {type(env)}"
            )
        # initialize the wrapper
        self.env = env
        # collect common information
        self.num_envs = self.unwrapped.num_envs
        self.sim_device = self.unwrapped.device
        self.render_mode = self.unwrapped.render_mode

        # obtain gym spaces
        # note: stable-baselines3 does not like when we have unbounded action space so
        #   we set it to some high value here. Maybe this is not general but something to think about.
        observation_space = self.unwrapped.single_observation_space["policy"]
        action_space = self.unwrapped.single_action_space
        if isinstance(action_space, gym.spaces.Box) and not action_space.is_bounded("both"):
            action_space = gym.spaces.Box(low=-100, high=100, shape=action_space.shape)

        # initialize vec-env
        VecEnv.__init__(self, self.num_envs, observation_space, action_space)
        # add buffer for logging episodic information
        self._ep_rew_buf = torch.zeros(self.num_envs, device=self.sim_device)
        self._ep_len_buf = torch.zeros(self.num_envs, device=self.sim_device)
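
    # A minimal construction sketch (illustrative; the task name is one example of a
    # registered Isaac Lab task, and environment creation normally also requires a
    # ``cfg`` argument built from the task's config class):
    #
    #     import gymnasium as gym
    #     import isaaclab_tasks  # noqa: F401  # registers the Isaac Lab tasks
    #
    #     env = gym.make("Isaac-Cartpole-v0", cfg=env_cfg)
    #     env = Sb3VecEnvWrapper(env)  # must be the last wrapper in the chain
    #     print(env.num_envs, env.observation_space, env.action_space)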
    def __str__(self):
        """Returns the wrapper name and the :attr:`env` representation string."""
        return f"<{type(self).__name__}{self.env}>"

    def __repr__(self):
        """Returns the string representation of the wrapper."""
        return str(self)

    """
    Properties -- Gym.Wrapper
    """
    @classmethod
    def class_name(cls) -> str:
        """Returns the class name of the wrapper."""
        return cls.__name__
    @property
    def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv:
        """Returns the base environment of the wrapper.

        This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
        """
        return self.env.unwrapped

    """
    Properties
    """
    def get_episode_rewards(self) -> list[float]:
        """Returns the rewards of all the episodes."""
        return self._ep_rew_buf.cpu().tolist()
    def get_episode_lengths(self) -> list[int]:
        """Returns the number of time-steps of all the episodes."""
        return self._ep_len_buf.cpu().tolist()
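
    # A usage sketch for the monitoring buffers (illustrative): the getters report the
    # running un-discounted return and length of the episode currently in progress in
    # each sub-environment (the buffers are zeroed on reset), so they can be polled
    # between rollouts, e.g. from a training callback:
    #
    #     mean_return = sum(env.get_episode_rewards()) / env.num_envs
    #     mean_length = sum(env.get_episode_lengths()) / env.num_envs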
""" Operations - MDP """defseed(self,seed:int|None=None)->list[int|None]:# noqa: D102return[self.unwrapped.seed(seed)]*self.unwrapped.num_envsdefreset(self)->VecEnvObs:# noqa: D102obs_dict,_=self.env.reset()# reset episodic information buffersself._ep_rew_buf.zero_()self._ep_len_buf.zero_()# convert data types to numpy depending on backendreturnself._process_obs(obs_dict)defstep_async(self,actions):# noqa: D102# convert input to numpy arrayifnotisinstance(actions,torch.Tensor):actions=np.asarray(actions)actions=torch.from_numpy(actions).to(device=self.sim_device,dtype=torch.float32)else:actions=actions.to(device=self.sim_device,dtype=torch.float32)# convert to tensorself._async_actions=actionsdefstep_wait(self)->VecEnvStepReturn:# noqa: D102# record step informationobs_dict,rew,terminated,truncated,extras=self.env.step(self._async_actions)# update episode un-discounted return and lengthself._ep_rew_buf+=rewself._ep_len_buf+=1# compute reset idsdones=terminated|truncatedreset_ids=(dones>0).nonzero(as_tuple=False)# convert data types to numpy depending on backend# note: ManagerBasedRLEnv uses torch backend (by default).obs=self._process_obs(obs_dict)rew=rew.detach().cpu().numpy()terminated=terminated.detach().cpu().numpy()truncated=truncated.detach().cpu().numpy()dones=dones.detach().cpu().numpy()# convert extra information to list of dictsinfos=self._process_extras(obs,terminated,truncated,extras,reset_ids)# reset info for terminated environmentsself._ep_rew_buf[reset_ids]=0self._ep_len_buf[reset_ids]=0returnobs,rew,dones,infosdefclose(self):# noqa: D102self.env.close()defget_attr(self,attr_name,indices=None):# noqa: D102# resolve indicesifindicesisNone:indices=slice(None)num_indices=self.num_envselse:num_indices=len(indices)# obtain attribute valueattr_val=getattr(self.env,attr_name)# return the valueifnotisinstance(attr_val,torch.Tensor):return[attr_val]*num_indiceselse:returnattr_val[indices].detach().cpu().numpy()defset_attr(self,attr_name,value,indices=None):# noqa: D102raiseNotImplementedError("Setting attributes is not supported.")defenv_method(self,method_name:str,*method_args,indices=None,**method_kwargs):# noqa: D102ifmethod_name=="render":# gymnasium does not support changing render mode at runtimereturnself.env.render()else:# this isn't properly implemented but it is not necessary.# mostly done for completeness.env_method=getattr(self.env,method_name)returnenv_method(*method_args,indices=indices,**method_kwargs)defenv_is_wrapped(self,wrapper_class,indices=None):# noqa: D102raiseNotImplementedError("Checking if environment is wrapped is not supported.")defget_images(self):# noqa: D102raiseNotImplementedError("Getting images is not supported.")""" Helper functions. 
"""def_process_obs(self,obs_dict:torch.Tensor|dict[str,torch.Tensor])->np.ndarray|dict[str,np.ndarray]:"""Convert observations into NumPy data type."""# Sb3 doesn't support asymmetric observation spaces, so we only use "policy"obs=obs_dict["policy"]# note: ManagerBasedRLEnv uses torch backend (by default).ifisinstance(obs,dict):forkey,valueinobs.items():obs[key]=value.detach().cpu().numpy()elifisinstance(obs,torch.Tensor):obs=obs.detach().cpu().numpy()else:raiseNotImplementedError(f"Unsupported data type: {type(obs)}")returnobsdef_process_extras(self,obs:np.ndarray,terminated:np.ndarray,truncated:np.ndarray,extras:dict,reset_ids:np.ndarray)->list[dict[str,Any]]:"""Convert miscellaneous information into dictionary for each sub-environment."""# create empty list of dictionaries to fillinfos:list[dict[str,Any]]=[dict.fromkeys(extras.keys())for_inrange(self.num_envs)]# fill-in information for each sub-environment# note: This loop becomes slow when number of environments is large.foridxinrange(self.num_envs):# fill-in episode monitoring infoifidxinreset_ids:infos[idx]["episode"]=dict()infos[idx]["episode"]["r"]=float(self._ep_rew_buf[idx])infos[idx]["episode"]["l"]=float(self._ep_len_buf[idx])else:infos[idx]["episode"]=None# fill-in bootstrap informationinfos[idx]["TimeLimit.truncated"]=truncated[idx]andnotterminated[idx]# fill-in information from extrasforkey,valueinextras.items():# 1. remap extra episodes information safely# 2. for others just store their valuesifkey=="log":# only log this data for episodes that are terminatedifinfos[idx]["episode"]isnotNone:forsub_key,sub_valueinvalue.items():infos[idx]["episode"][sub_key]=sub_valueelse:infos[idx][key]=value[idx]# add information about terminal observation separatelyifidxinreset_ids:# extract terminal observationsifisinstance(obs,dict):terminal_obs=dict.fromkeys(obs.keys())forkey,valueinobs.items():terminal_obs[key]=value[idx]else:terminal_obs=obs[idx]# add info to dictinfos[idx]["terminal_observation"]=terminal_obselse:infos[idx]["terminal_observation"]=None# return list of dictionariesreturninfos