# Source code for omni.isaac.lab_tasks.utils.wrappers.sb3

# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Wrapper to configure a :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv` instance to Stable-Baselines3 vectorized environment.

The following example shows how to wrap an environment for Stable-Baselines3:

.. code-block:: python

    from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper

    env = Sb3VecEnvWrapper(env)

"""

# needed to import for allowing type-hinting: torch.Tensor | dict[str, torch.Tensor]
from __future__ import annotations

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn  # noqa: F401
from typing import Any

from stable_baselines3.common.utils import constant_fn
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn

from omni.isaac.lab.envs import DirectRLEnv, ManagerBasedRLEnv

"""
Configuration Parser.
"""


def process_sb3_cfg(cfg: dict) -> dict:
    """Convert simple YAML types to Stable-Baselines classes/components.

    The dictionary is walked recursively and the following keys are converted:

    * ``policy_kwargs``, ``replay_buffer_class``, ``replay_buffer_kwargs``: string values are
      evaluated as Python expressions (e.g. ``"dict(activation_fn=nn.ELU)"``).
    * ``learning_rate``, ``clip_range``, ``clip_range_vf``, ``delta_std``: a string of the form
      ``"<prefix>_<value>"`` (e.g. ``"lin_0.001"``) becomes a linear schedule
      ``progress_remaining * value``; a non-negative number becomes a constant schedule;
      a negative number is left untouched (e.g. to disable clipping).

    The input dictionary is modified in place and also returned.

    Args:
        cfg: A configuration dictionary.

    Returns:
        A dictionary containing the converted configuration.

    Raises:
        ValueError: If a schedule key holds a value that is neither a string nor a number.

    Reference:
        https://github.com/DLR-RM/rl-baselines3-zoo/blob/0e5eb145faefa33e7d79c7f8c179788574b20da5/utils/exp_manager.py#L358
    """

    def make_linear_schedule(initial_value: float):
        """Return a schedule that scales ``initial_value`` by the remaining training progress."""

        # A dedicated factory gives every schedule its own binding of ``initial_value``. A lambda
        # created directly inside the loop below would close over the loop-local variable, so all
        # schedules in the same dictionary would silently use the last parsed value.
        def schedule(progress_remaining: float) -> float:
            return progress_remaining * initial_value

        return schedule

    def update_dict(hyperparams: dict[str, Any]) -> dict[str, Any]:
        for key, value in hyperparams.items():
            if isinstance(value, dict):
                # recurse into nested sections; they are modified in place
                update_dict(value)
            elif key in ["policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"]:
                # SECURITY: ``eval`` executes arbitrary code -- only load trusted configuration files.
                hyperparams[key] = eval(value)
            elif key in ["learning_rate", "clip_range", "clip_range_vf", "delta_std"]:
                if isinstance(value, str):
                    _, initial_value = value.split("_")
                    hyperparams[key] = make_linear_schedule(float(initial_value))
                elif isinstance(value, (float, int)):
                    # Negative value: ignore (ex: for clipping)
                    if value < 0:
                        continue
                    hyperparams[key] = constant_fn(float(value))
                else:
                    raise ValueError(f"Invalid value for {key}: {hyperparams[key]}")
        return hyperparams

    # parse agent configuration and convert to classes
    return update_dict(cfg)
""" Vectorized environment wrapper. """
class Sb3VecEnvWrapper(VecEnv):
    """Wraps around Isaac Lab environment for Stable Baselines3.

    Isaac Sim internally implements a vectorized environment. However, since it is
    still considered a single environment instance, Stable Baselines tries to wrap
    around it using the :class:`DummyVecEnv`. This is only done if the environment
    is not inheriting from their :class:`VecEnv`. Thus, this class thinly wraps
    over the environment from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.

    Note:
        While Stable-Baselines3 supports Gym 0.26+ API, their vectorized environment
        still uses the old API (i.e. it is closer to Gym 0.21). Thus, we implement
        the old API for the vectorized environment.

    We also add monitoring functionality that computes the un-discounted episode
    return and length. This information is added to the info dicts under key ``episode``.

    In contrast to the Isaac Lab environment, stable-baselines expects the following:

    1. numpy datatype for MDP signals
    2. a list of info dicts for each sub-environment (instead of a dict)
    3. when an environment has terminated, the observations from the environment should
       correspond to the one after reset. The "real" final observation is passed using the
       info dicts under the key ``terminal_observation``.

    .. warning::

        By the nature of physics stepping in Isaac Sim, it is not possible to forward the
        simulation buffers without performing a physics step. Thus, reset is performed
        inside the :meth:`step()` function after the actual physics step is taken.
        Consequently, the returned observations for terminated environments are the ones
        after the reset, and the value stored under ``terminal_observation`` is that same
        post-reset observation.

    .. caution::

        This class must be the last wrapper in the wrapper chain. This is because the wrapper does not follow
        the :class:`gym.Wrapper` interface. Any subsequent wrappers will need to be modified to work with this
        wrapper.

    Reference:

    1. https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html
    2. https://stable-baselines3.readthedocs.io/en/master/common/monitor.html

    """

    def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv):
        """Initialize the wrapper.

        Args:
            env: The environment to wrap around.

        Raises:
            ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
        """
        # check that input is valid
        if not isinstance(env.unwrapped, ManagerBasedRLEnv) and not isinstance(env.unwrapped, DirectRLEnv):
            raise ValueError(
                "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:"
                f" {type(env)}"
            )
        # initialize the wrapper
        self.env = env
        # collect common information
        self.num_envs = self.unwrapped.num_envs
        self.sim_device = self.unwrapped.device
        self.render_mode = self.unwrapped.render_mode

        # obtain gym spaces
        # note: stable-baselines3 does not like when we have unbounded action space so
        #   we set it to some high value here. Maybe this is not general but something to think about.
        observation_space = self.unwrapped.single_observation_space["policy"]
        action_space = self.unwrapped.single_action_space
        if isinstance(action_space, gym.spaces.Box) and not action_space.is_bounded("both"):
            action_space = gym.spaces.Box(low=-100, high=100, shape=action_space.shape)

        # initialize vec-env
        VecEnv.__init__(self, self.num_envs, observation_space, action_space)
        # add buffer for logging episodic information
        self._ep_rew_buf = torch.zeros(self.num_envs, device=self.sim_device)
        self._ep_len_buf = torch.zeros(self.num_envs, device=self.sim_device)

    def __str__(self):
        """Returns the wrapper name and the :attr:`env` representation string."""
        return f"<{type(self).__name__}{self.env}>"

    def __repr__(self):
        """Returns the string representation of the wrapper."""
        return str(self)

    """
    Properties -- Gym.Wrapper
    """

    @classmethod
    def class_name(cls) -> str:
        """Returns the class name of the wrapper."""
        return cls.__name__

    @property
    def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv:
        """Returns the base environment of the wrapper.

        This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
        """
        return self.env.unwrapped

    """
    Properties
    """

    def get_episode_rewards(self) -> list[float]:
        """Returns the rewards of all the episodes."""
        return self._ep_rew_buf.cpu().tolist()

    def get_episode_lengths(self) -> list[int]:
        """Returns the number of time-steps of all the episodes."""
        # the length buffer is stored as floats; cast so the result matches the annotation
        return [int(length) for length in self._ep_len_buf.cpu().tolist()]

    """
    Operations - MDP
    """

    def seed(self, seed: int | None = None) -> list[int | None]:  # noqa: D102
        return [self.unwrapped.seed(seed)] * self.unwrapped.num_envs

    def reset(self) -> VecEnvObs:  # noqa: D102
        obs_dict, _ = self.env.reset()
        # reset episodic information buffers
        self._ep_rew_buf.zero_()
        self._ep_len_buf.zero_()
        # convert data types to numpy depending on backend
        return self._process_obs(obs_dict)

    def step_async(self, actions):  # noqa: D102
        # convert input to a float32 tensor on the simulation device
        if not isinstance(actions, torch.Tensor):
            actions = np.asarray(actions)
            actions = torch.from_numpy(actions).to(device=self.sim_device, dtype=torch.float32)
        else:
            actions = actions.to(device=self.sim_device, dtype=torch.float32)
        # store actions until step_wait() is called
        self._async_actions = actions

    def step_wait(self) -> VecEnvStepReturn:  # noqa: D102
        # record step information
        obs_dict, rew, terminated, truncated, extras = self.env.step(self._async_actions)
        # update episode un-discounted return and length
        self._ep_rew_buf += rew
        self._ep_len_buf += 1
        # mask (on the simulation device) of sub-environments whose episode ended this step
        dones = terminated | truncated
        done_mask = dones > 0

        # convert data types to numpy depending on backend
        # note: ManagerBasedRLEnv uses torch backend (by default).
        obs = self._process_obs(obs_dict)
        rew_np = rew.detach().cpu().numpy()
        terminated_np = terminated.detach().cpu().numpy()
        truncated_np = truncated.detach().cpu().numpy()
        dones_np = dones.detach().cpu().numpy()
        # compute reset ids on the host once; a device tensor here would force a device
        # synchronization for every per-environment membership test in `_process_extras`
        reset_ids = np.flatnonzero(dones_np > 0)

        # convert extra information to list of dicts
        infos = self._process_extras(obs, terminated_np, truncated_np, extras, reset_ids)

        # reset info for terminated environments
        self._ep_rew_buf[done_mask] = 0
        self._ep_len_buf[done_mask] = 0

        return obs, rew_np, dones_np, infos

    def close(self):  # noqa: D102
        self.env.close()

    def get_attr(self, attr_name, indices=None):  # noqa: D102
        # resolve indices
        if indices is None:
            indices = slice(None)
            num_indices = self.num_envs
        else:
            num_indices = len(indices)
        # obtain attribute value
        attr_val = getattr(self.env, attr_name)
        # return the value
        if not isinstance(attr_val, torch.Tensor):
            return [attr_val] * num_indices
        else:
            return attr_val[indices].detach().cpu().numpy()

    def set_attr(self, attr_name, value, indices=None):  # noqa: D102
        raise NotImplementedError("Setting attributes is not supported.")

    def env_method(self, method_name: str, *method_args, indices=None, **method_kwargs):  # noqa: D102
        if method_name == "render":
            # gymnasium does not support changing render mode at runtime
            return self.env.render()
        else:
            # this isn't properly implemented but it is not necessary.
            # mostly done for completeness.
            env_method = getattr(self.env, method_name)
            return env_method(*method_args, indices=indices, **method_kwargs)

    def env_is_wrapped(self, wrapper_class, indices=None):  # noqa: D102
        raise NotImplementedError("Checking if environment is wrapped is not supported.")

    def get_images(self):  # noqa: D102
        raise NotImplementedError("Getting images is not supported.")

    """
    Helper functions.
    """

    def _process_obs(self, obs_dict: torch.Tensor | dict[str, torch.Tensor]) -> np.ndarray | dict[str, np.ndarray]:
        """Convert observations into NumPy data type.

        Only the ``"policy"`` observation group is returned. For dictionary observations a new
        dictionary is built, so the environment's own observation dictionary is not overwritten
        with NumPy arrays.

        Raises:
            NotImplementedError: If the ``"policy"`` group is neither a tensor nor a dictionary.
        """
        # Sb3 doesn't support asymmetric observation spaces, so we only use "policy"
        obs = obs_dict["policy"]
        # note: ManagerBasedRLEnv uses torch backend (by default).
        if isinstance(obs, dict):
            return {key: value.detach().cpu().numpy() for key, value in obs.items()}
        if isinstance(obs, torch.Tensor):
            return obs.detach().cpu().numpy()
        raise NotImplementedError(f"Unsupported data type: {type(obs)}")

    def _process_extras(
        self, obs: np.ndarray, terminated: np.ndarray, truncated: np.ndarray, extras: dict, reset_ids: np.ndarray
    ) -> list[dict[str, Any]]:
        """Convert miscellaneous information into dictionary for each sub-environment.

        Args:
            obs: Observations already converted to NumPy (array or dict of arrays).
            terminated: Per-environment termination flags (NumPy).
            truncated: Per-environment truncation flags (NumPy).
            extras: Extra information returned by the environment step.
            reset_ids: Indices of the sub-environments whose episode ended this step.

        Returns:
            One info dictionary per sub-environment.
        """
        # accept a device tensor as well, for callers that still pass torch indices
        if isinstance(reset_ids, torch.Tensor):
            reset_ids = reset_ids.detach().cpu().numpy()
        # O(1) membership test per sub-environment
        finished = set(np.asarray(reset_ids).reshape(-1).tolist())
        # copy the episodic buffers to the host once instead of once per sub-environment
        ep_rewards = self._ep_rew_buf.detach().cpu().numpy()
        ep_lengths = self._ep_len_buf.detach().cpu().numpy()

        # create empty list of dictionaries to fill
        infos: list[dict[str, Any]] = [dict.fromkeys(extras.keys()) for _ in range(self.num_envs)]
        # fill-in information for each sub-environment
        # note: This loop becomes slow when number of environments is large.
        for idx, info in enumerate(infos):
            done = idx in finished
            # fill-in episode monitoring info
            if done:
                info["episode"] = {"r": float(ep_rewards[idx]), "l": float(ep_lengths[idx])}
            else:
                info["episode"] = None
            # fill-in bootstrap information
            info["TimeLimit.truncated"] = truncated[idx] and not terminated[idx]
            # fill-in information from extras
            for key, value in extras.items():
                if key == "log":
                    # 1. remap extra episodes information, only for episodes that ended
                    if done:
                        for sub_key, sub_value in value.items():
                            info["episode"][sub_key] = sub_value
                else:
                    # 2. for others just store their per-environment values
                    info[key] = value[idx]
            # add information about terminal observation separately
            if not done:
                info["terminal_observation"] = None
            elif isinstance(obs, dict):
                info["terminal_observation"] = {key: value[idx] for key, value in obs.items()}
            else:
                info["terminal_observation"] = obs[idx]
        # return list of dictionaries
        return infos