Creating a Direct Workflow RL Environment#

In addition to the envs.ManagerBasedRLEnv class, which encourages the use of configuration classes for more modular environments, the envs.DirectRLEnv class allows for more direct control in the scripting of the environment.

Instead of using Manager classes to define rewards and observations, direct workflow tasks implement the full reward and observation functions directly in the task script. This allows for more control over the implementation of these methods, such as using PyTorch JIT features, and provides a less abstracted framework that makes it easier to find the various pieces of code.

In this tutorial, we will configure the cartpole environment using the direct workflow implementation to create a task for balancing the pole upright. We will learn how to specify the task by implementing functions for scene creation, actions, resets, rewards, and observations.

The Code#

For this tutorial, we use the cartpole environment defined in the omni.isaac.lab_tasks.direct.cartpole module.

Code for cartpole_env.py
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import math
import torch
from collections.abc import Sequence

from omni.isaac.lab_assets.cartpole import CARTPOLE_CFG

import omni.isaac.lab.sim as sim_utils
from omni.isaac.lab.assets import Articulation, ArticulationCfg
from omni.isaac.lab.envs import DirectRLEnv, DirectRLEnvCfg
from omni.isaac.lab.scene import InteractiveSceneCfg
from omni.isaac.lab.sim import SimulationCfg
from omni.isaac.lab.sim.spawners.from_files import GroundPlaneCfg, spawn_ground_plane
from omni.isaac.lab.utils import configclass
from omni.isaac.lab.utils.math import sample_uniform


@configclass
class CartpoleEnvCfg(DirectRLEnvCfg):
    # env
    decimation = 2
    episode_length_s = 5.0
    action_scale = 100.0  # [N]
    num_actions = 1
    num_observations = 4
    num_states = 0

    # simulation
    sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation)

    # robot
    robot_cfg: ArticulationCfg = CARTPOLE_CFG.replace(prim_path="/World/envs/env_.*/Robot")
    cart_dof_name = "slider_to_cart"
    pole_dof_name = "cart_to_pole"

    # scene
    scene: InteractiveSceneCfg = InteractiveSceneCfg(num_envs=4096, env_spacing=4.0, replicate_physics=True)

    # reset
    max_cart_pos = 3.0  # the cart is reset if it exceeds that position [m]
    initial_pole_angle_range = [-0.25, 0.25]  # the range in which the pole angle is sampled from on reset [rad]

    # reward scales
    rew_scale_alive = 1.0
    rew_scale_terminated = -2.0
    rew_scale_pole_pos = -1.0
    rew_scale_cart_vel = -0.01
    rew_scale_pole_vel = -0.005


class CartpoleEnv(DirectRLEnv):
    cfg: CartpoleEnvCfg

    def __init__(self, cfg: CartpoleEnvCfg, render_mode: str | None = None, **kwargs):
        super().__init__(cfg, render_mode, **kwargs)

        self._cart_dof_idx, _ = self.cartpole.find_joints(self.cfg.cart_dof_name)
        self._pole_dof_idx, _ = self.cartpole.find_joints(self.cfg.pole_dof_name)
        self.action_scale = self.cfg.action_scale

        self.joint_pos = self.cartpole.data.joint_pos
        self.joint_vel = self.cartpole.data.joint_vel

    def _setup_scene(self):
        self.cartpole = Articulation(self.cfg.robot_cfg)
        # add ground plane
        spawn_ground_plane(prim_path="/World/ground", cfg=GroundPlaneCfg())
        # clone, filter, and replicate
        self.scene.clone_environments(copy_from_source=False)
        self.scene.filter_collisions(global_prim_paths=[])
        # add articulation to scene
        self.scene.articulations["cartpole"] = self.cartpole
        # add lights
        light_cfg = sim_utils.DomeLightCfg(intensity=2000.0, color=(0.75, 0.75, 0.75))
        light_cfg.func("/World/Light", light_cfg)

    def _pre_physics_step(self, actions: torch.Tensor) -> None:
        self.actions = self.action_scale * actions.clone()

    def _apply_action(self) -> None:
        self.cartpole.set_joint_effort_target(self.actions, joint_ids=self._cart_dof_idx)

    def _get_observations(self) -> dict:
        obs = torch.cat(
            (
                self.joint_pos[:, self._pole_dof_idx[0]].unsqueeze(dim=1),
                self.joint_vel[:, self._pole_dof_idx[0]].unsqueeze(dim=1),
                self.joint_pos[:, self._cart_dof_idx[0]].unsqueeze(dim=1),
                self.joint_vel[:, self._cart_dof_idx[0]].unsqueeze(dim=1),
            ),
            dim=-1,
        )
        observations = {"policy": obs}
        return observations

    def _get_rewards(self) -> torch.Tensor:
        total_reward = compute_rewards(
            self.cfg.rew_scale_alive,
            self.cfg.rew_scale_terminated,
            self.cfg.rew_scale_pole_pos,
            self.cfg.rew_scale_cart_vel,
            self.cfg.rew_scale_pole_vel,
            self.joint_pos[:, self._pole_dof_idx[0]],
            self.joint_vel[:, self._pole_dof_idx[0]],
            self.joint_pos[:, self._cart_dof_idx[0]],
            self.joint_vel[:, self._cart_dof_idx[0]],
            self.reset_terminated,
        )
        return total_reward

    def _get_dones(self) -> tuple[torch.Tensor, torch.Tensor]:
        self.joint_pos = self.cartpole.data.joint_pos
        self.joint_vel = self.cartpole.data.joint_vel

        time_out = self.episode_length_buf >= self.max_episode_length - 1
        out_of_bounds = torch.any(torch.abs(self.joint_pos[:, self._cart_dof_idx]) > self.cfg.max_cart_pos, dim=1)
        out_of_bounds = out_of_bounds | torch.any(torch.abs(self.joint_pos[:, self._pole_dof_idx]) > math.pi / 2, dim=1)
        return out_of_bounds, time_out

    def _reset_idx(self, env_ids: Sequence[int] | None):
        if env_ids is None:
            env_ids = self.cartpole._ALL_INDICES
        super()._reset_idx(env_ids)

        joint_pos = self.cartpole.data.default_joint_pos[env_ids]
        joint_pos[:, self._pole_dof_idx] += sample_uniform(
            self.cfg.initial_pole_angle_range[0] * math.pi,
            self.cfg.initial_pole_angle_range[1] * math.pi,
            joint_pos[:, self._pole_dof_idx].shape,
            joint_pos.device,
        )
        joint_vel = self.cartpole.data.default_joint_vel[env_ids]

        default_root_state = self.cartpole.data.default_root_state[env_ids]
        default_root_state[:, :3] += self.scene.env_origins[env_ids]

        self.joint_pos[env_ids] = joint_pos
        self.joint_vel[env_ids] = joint_vel

        self.cartpole.write_root_pose_to_sim(default_root_state[:, :7], env_ids)
        self.cartpole.write_root_velocity_to_sim(default_root_state[:, 7:], env_ids)
        self.cartpole.write_joint_state_to_sim(joint_pos, joint_vel, None, env_ids)


@torch.jit.script
def compute_rewards(
    rew_scale_alive: float,
    rew_scale_terminated: float,
    rew_scale_pole_pos: float,
    rew_scale_cart_vel: float,
    rew_scale_pole_vel: float,
    pole_pos: torch.Tensor,
    pole_vel: torch.Tensor,
    cart_pos: torch.Tensor,
    cart_vel: torch.Tensor,
    reset_terminated: torch.Tensor,
):
    rew_alive = rew_scale_alive * (1.0 - reset_terminated.float())
    rew_termination = rew_scale_terminated * reset_terminated.float()
    rew_pole_pos = rew_scale_pole_pos * torch.sum(torch.square(pole_pos).unsqueeze(dim=1), dim=-1)
    rew_cart_vel = rew_scale_cart_vel * torch.sum(torch.abs(cart_vel).unsqueeze(dim=1), dim=-1)
    rew_pole_vel = rew_scale_pole_vel * torch.sum(torch.abs(pole_vel).unsqueeze(dim=1), dim=-1)
    total_reward = rew_alive + rew_termination + rew_pole_pos + rew_cart_vel + rew_pole_vel
    return total_reward

The Code Explained#

Similar to the manager-based environments, a configuration class is defined for the task to hold settings for the simulation parameters, the scene, the actors, and the task. With the direct workflow implementation, the envs.DirectRLEnvCfg class is used as the base class for configurations. Since the direct workflow implementation does not use Action and Observation managers, the task config should define the number of actions and observations for the environment.

@configclass
class CartpoleEnvCfg(DirectRLEnvCfg):
   ...
   num_actions = 1
   num_observations = 4
   num_states = 0

The config class can also be used to define task-specific attributes, such as scaling for reward terms and thresholds for reset conditions.

@configclass
class CartpoleEnvCfg(DirectRLEnvCfg):
   ...
   # reset
   max_cart_pos = 3.0
   initial_pole_angle_range = [-0.25, 0.25]

   # reward scales
   rew_scale_alive = 1.0
   rew_scale_terminated = -2.0
   rew_scale_pole_pos = -1.0
   rew_scale_cart_vel = -0.01
   rew_scale_pole_vel = -0.005

When creating a new environment, the code should define a new class that inherits from DirectRLEnv.

class CartpoleEnv(DirectRLEnv):
   cfg: CartpoleEnvCfg

   def __init__(self, cfg: CartpoleEnvCfg, render_mode: str | None = None, **kwargs):
     super().__init__(cfg, render_mode, **kwargs)

The class can also hold class variables that are accessible by all functions in the class, including functions for applying actions, computing resets, rewards, and observations.
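
For example, the cartpole environment's __init__ (from the full listing above) caches the joint indices, the action scale, and views of the joint state tensors so that the action, reward, and observation functions can reuse them:

    def __init__(self, cfg: CartpoleEnvCfg, render_mode: str | None = None, **kwargs):
        super().__init__(cfg, render_mode, **kwargs)

        # cache joint indices and commonly used buffers as class variables
        self._cart_dof_idx, _ = self.cartpole.find_joints(self.cfg.cart_dof_name)
        self._pole_dof_idx, _ = self.cartpole.find_joints(self.cfg.pole_dof_name)
        self.action_scale = self.cfg.action_scale

        self.joint_pos = self.cartpole.data.joint_pos
        self.joint_vel = self.cartpole.data.joint_vel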

Scene Creation#

In contrast to manager-based environments, where scene creation is taken care of by the framework, the direct workflow implementation provides flexibility for users to implement their own scene creation function. This includes adding actors into the stage, cloning the environments, filtering collisions between the environments, adding the actors into the scene, and adding any additional props to the scene, such as a ground plane and lights. These operations should be implemented in the _setup_scene(self) method.

    def _setup_scene(self):
        self.cartpole = Articulation(self.cfg.robot_cfg)
        # add ground plane
        spawn_ground_plane(prim_path="/World/ground", cfg=GroundPlaneCfg())
        # clone, filter, and replicate
        self.scene.clone_environments(copy_from_source=False)
        self.scene.filter_collisions(global_prim_paths=[])
        # add articulation to scene
        self.scene.articulations["cartpole"] = self.cartpole
        # add lights
        light_cfg = sim_utils.DomeLightCfg(intensity=2000.0, color=(0.75, 0.75, 0.75))
        light_cfg.func("/World/Light", light_cfg)

Defining Rewards#

The reward function should be defined in the _get_rewards(self) API, which returns the reward buffer. Within this function, the task is free to implement the logic of the reward function. In this example, we implement a PyTorch JIT-scripted function that computes the various components of the reward.

    def _get_rewards(self) -> torch.Tensor:
        total_reward = compute_rewards(
            self.cfg.rew_scale_alive,
            self.cfg.rew_scale_terminated,
            self.cfg.rew_scale_pole_pos,
            self.cfg.rew_scale_cart_vel,
            self.cfg.rew_scale_pole_vel,
            self.joint_pos[:, self._pole_dof_idx[0]],
            self.joint_vel[:, self._pole_dof_idx[0]],
            self.joint_pos[:, self._cart_dof_idx[0]],
            self.joint_vel[:, self._cart_dof_idx[0]],
            self.reset_terminated,
        )
        return total_reward

@torch.jit.script
def compute_rewards(
    rew_scale_alive: float,
    rew_scale_terminated: float,
    rew_scale_pole_pos: float,
    rew_scale_cart_vel: float,
    rew_scale_pole_vel: float,
    pole_pos: torch.Tensor,
    pole_vel: torch.Tensor,
    cart_pos: torch.Tensor,
    cart_vel: torch.Tensor,
    reset_terminated: torch.Tensor,
):
    rew_alive = rew_scale_alive * (1.0 - reset_terminated.float())
    rew_termination = rew_scale_terminated * reset_terminated.float()
    rew_pole_pos = rew_scale_pole_pos * torch.sum(torch.square(pole_pos).unsqueeze(dim=1), dim=-1)
    rew_cart_vel = rew_scale_cart_vel * torch.sum(torch.abs(cart_vel).unsqueeze(dim=1), dim=-1)
    rew_pole_vel = rew_scale_pole_vel * torch.sum(torch.abs(pole_vel).unsqueeze(dim=1), dim=-1)
    total_reward = rew_alive + rew_termination + rew_pole_pos + rew_cart_vel + rew_pole_vel
    return total_reward

Defining Observations#

The observation buffer should be computed in the _get_observations(self) function, which constructs the observation buffer for the environment. At the end of this API, a dictionary should be returned that contains policy as the key, and the full observation buffer as the value. For asymmetric policies, the dictionary should also include the key critic and the states buffer as the value.

    def _get_observations(self) -> dict:
        obs = torch.cat(
            (
                self.joint_pos[:, self._pole_dof_idx[0]].unsqueeze(dim=1),
                self.joint_vel[:, self._pole_dof_idx[0]].unsqueeze(dim=1),
                self.joint_pos[:, self._cart_dof_idx[0]].unsqueeze(dim=1),
                self.joint_vel[:, self._cart_dof_idx[0]].unsqueeze(dim=1),
            ),
            dim=-1,
        )
        observations = {"policy": obs}
        return observations
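
For an asymmetric actor-critic setup, the returned dictionary would additionally carry the privileged states. A minimal sketch of what the return value might look like, assuming the task builds a hypothetical states tensor and sets num_states accordingly:

        # hypothetical sketch for an asymmetric actor-critic setup:
        # "states" is a privileged observation tensor assembled by the task
        observations = {"policy": obs, "critic": states}
        return observations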

Computing Dones and Performing Resets#

Populating the dones buffer should be done in the _get_dones(self) method. This method is free to implement logic that computes which environments need to be reset and which environments have reached the episode length limit. Both results should be returned by the _get_dones(self) function as a tuple of boolean tensors.

    def _get_dones(self) -> tuple[torch.Tensor, torch.Tensor]:
        self.joint_pos = self.cartpole.data.joint_pos
        self.joint_vel = self.cartpole.data.joint_vel

        time_out = self.episode_length_buf >= self.max_episode_length - 1
        out_of_bounds = torch.any(torch.abs(self.joint_pos[:, self._cart_dof_idx]) > self.cfg.max_cart_pos, dim=1)
        out_of_bounds = out_of_bounds | torch.any(torch.abs(self.joint_pos[:, self._pole_dof_idx]) > math.pi / 2, dim=1)
        return out_of_bounds, time_out

Once the indices for environments requiring reset have been computed, the _reset_idx(self, env_ids) function performs the reset operations on those environments. Within this function, new states for the environments requiring reset should be written directly to the simulation.

    def _reset_idx(self, env_ids: Sequence[int] | None):
        if env_ids is None:
            env_ids = self.cartpole._ALL_INDICES
        super()._reset_idx(env_ids)

        joint_pos = self.cartpole.data.default_joint_pos[env_ids]
        joint_pos[:, self._pole_dof_idx] += sample_uniform(
            self.cfg.initial_pole_angle_range[0] * math.pi,
            self.cfg.initial_pole_angle_range[1] * math.pi,
            joint_pos[:, self._pole_dof_idx].shape,
            joint_pos.device,
        )
        joint_vel = self.cartpole.data.default_joint_vel[env_ids]

        default_root_state = self.cartpole.data.default_root_state[env_ids]
        default_root_state[:, :3] += self.scene.env_origins[env_ids]

        self.joint_pos[env_ids] = joint_pos
        self.joint_vel[env_ids] = joint_vel

        self.cartpole.write_root_pose_to_sim(default_root_state[:, :7], env_ids)
        self.cartpole.write_root_velocity_to_sim(default_root_state[:, 7:], env_ids)
        self.cartpole.write_joint_state_to_sim(joint_pos, joint_vel, None, env_ids)

Applying Actions#

There are two APIs that are designed for working with actions. The _pre_physics_step(self, actions) takes in actions from the policy as an argument and is called once per RL step, prior to taking any physics steps. This function can be used to process the actions buffer from the policy and cache the data in a class variable for the environment.

    def _pre_physics_step(self, actions: torch.Tensor) -> None:
        self.actions = self.action_scale * actions.clone()

The _apply_action(self) API is called decimation times per RL step, immediately before each physics step. This provides more flexibility for environments where actions should be updated at every physics step. For example, with the configuration above (dt=1/120 and decimation=2), _apply_action runs twice per RL step, so the policy acts at 60 Hz while the physics simulation steps at 120 Hz.

    def _apply_action(self) -> None:
        self.cartpole.set_joint_effort_target(self.actions, joint_ids=self._cart_dof_idx)
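
Conceptually, the base class drives these two hooks roughly as in the following simplified sketch (illustrative only, not the actual DirectRLEnv.step() implementation, which also handles events, rendering, and buffer updates):

    # simplified, illustrative sketch of how the two action hooks are invoked
    def step(self, action):
        self._pre_physics_step(action)        # once per RL step
        for _ in range(self.cfg.decimation):  # once per physics step
            self._apply_action()
            self.scene.write_data_to_sim()
            self.sim.step(render=False)
            self.scene.update(dt=self.physics_dt)
        # ... then compute rewards, dones, observations, and perform resets ...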

The Code Execution#

To run training for the direct workflow Cartpole environment, we can use the following command:

./isaaclab.sh -p source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-Direct-v0

All direct workflow tasks have the suffix -Direct added to the task name to differentiate the implementation style.
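
The standard launcher flags can be appended as well, for example to train without the GUI and with a smaller number of environments (flag names assumed from the common Isaac Lab training script arguments):

./isaaclab.sh -p source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-Direct-v0 --headless --num_envs 64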

Domain Randomization#

In the direct workflow, domain randomization configuration uses the configclass module to specify a configuration class consisting of EventTermCfg variables.

Below is an example of a configuration class for domain randomization. The imports shown at the top are assumptions about the typical import aliases and may need adjusting for your project:

# assumed imports for this example (EventTerm is an alias for EventTermCfg)
from omni.isaac.lab.managers import EventTermCfg as EventTerm
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.utils import configclass
import omni.isaac.lab.envs.mdp as mdp


@configclass
class EventCfg:
  robot_physics_material = EventTerm(
      func=mdp.randomize_rigid_body_material,
      mode="reset",
      params={
          "asset_cfg": SceneEntityCfg("robot", body_names=".*"),
          "static_friction_range": (0.7, 1.3),
          "dynamic_friction_range": (1.0, 1.0),
          "restitution_range": (1.0, 1.0),
          "num_buckets": 250,
      },
  )
  robot_joint_stiffness_and_damping = EventTerm(
      func=mdp.randomize_actuator_gains,
      mode="reset",
      params={
          "asset_cfg": SceneEntityCfg("robot", joint_names=".*"),
          "stiffness_distribution_params": (0.75, 1.5),
          "damping_distribution_params": (0.3, 3.0),
          "operation": "scale",
          "distribution": "log_uniform",
      },
  )
  reset_gravity = EventTerm(
      func=mdp.randomize_physics_scene_gravity,
      mode="interval",
      is_global_time=True,
      interval_range_s=(36.0, 36.0),  # time_s = num_steps * (decimation * dt)
      params={
          "gravity_distribution_params": ([0.0, 0.0, 0.0], [0.0, 0.0, 0.4]),
          "operation": "add",
          "distribution": "gaussian",
      },
  )

Each EventTerm object is of the EventTermCfg class and takes in a func parameter specifying the function to call during randomization, and a mode parameter, which can be startup, reset, or interval. The params dictionary should provide the necessary arguments to the function specified in the func parameter. Functions specified as func for the EventTerm can be found in the events module.

Note that as part of the "asset_cfg": SceneEntityCfg("robot", body_names=".*") parameter, the name of the actor "robot" is provided, along with the body or joint names specified as a regular expression; these determine which actors and bodies/joints have randomization applied.

Once the configclass for the randomization terms has been set up, the class must be added to the base config class for the task and assigned to the variable events.

@configclass
class MyTaskConfig:
  events: EventCfg = EventCfg()

Action and Observation Noise#

Action and observation noise can also be added using the configclass module. Action and observation noise configs must be added to the main task config using the action_noise_model and observation_noise_model variables. The noise configuration classes used below are assumed to be imported from omni.isaac.lab.utils.noise:

# assumed imports for this example
from omni.isaac.lab.utils.noise import GaussianNoiseCfg, NoiseModelWithAdditiveBiasCfg


@configclass
class MyTaskConfig:

    # at every time-step add gaussian noise + bias. The bias is a gaussian sampled at reset
    action_noise_model: NoiseModelWithAdditiveBiasCfg = NoiseModelWithAdditiveBiasCfg(
      noise_cfg=GaussianNoiseCfg(mean=0.0, std=0.05, operation="add"),
      bias_noise_cfg=GaussianNoiseCfg(mean=0.0, std=0.015, operation="abs"),
    )

    # at every time-step add gaussian noise + bias. The bias is a gaussian sampled at reset
    observation_noise_model: NoiseModelWithAdditiveBiasCfg = NoiseModelWithAdditiveBiasCfg(
      noise_cfg=GaussianNoiseCfg(mean=0.0, std=0.002, operation="add"),
      bias_noise_cfg=GaussianNoiseCfg(mean=0.0, std=0.0001, operation="abs"),
    )

NoiseModelWithAdditiveBiasCfg can be used to sample both uncorrelated noise per step as well as correlated noise that is re-sampled at reset time.

The noise_cfg term specifies the Gaussian distribution that will be sampled at each step for all environments. This noise will be added to the corresponding actions and observations buffers at every step.

The bias_noise_cfg term specifies the Gaussian distribution for the correlated noise that will be sampled at reset time for the environments being reset. The same sampled bias will be applied at each step for the remainder of the episode for those environments and will be re-sampled at the next reset.

If only per-step noise is desired, GaussianNoiseCfg can be used to specify an additive Gaussian distribution that adds the sampled noise to the input buffer.

@configclass
class MyTaskConfig:
  action_noise_model: GaussianNoiseCfg = GaussianNoiseCfg(mean=0.0, std=0.05, operation="add")

In this tutorial, we learned how to create a direct workflow task environment for reinforcement learning. We did this by extending the base environment to include the scene setup, action, done, reset, reward, and observation functions.

While it is possible to manually create an instance of the DirectRLEnv class for a desired task, this is not scalable, as it requires a specialized script for each task. Thus, we use the gymnasium.make() function to create the environment through the gym interface. We will learn how to do this in the next tutorial.
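
As a brief preview, creating the environment through the registry might look like the following minimal sketch. It assumes the task is registered under the Isaac-Cartpole-Direct-v0 ID used in the training command above and that the simulation app has already been launched; the registration details are covered in the next tutorial.

import gymnasium as gym

from omni.isaac.lab_tasks.direct.cartpole import CartpoleEnvCfg

# note: the simulation app must be running (e.g. launched via the AppLauncher utility)
# before the environment can be created
env = gym.make("Isaac-Cartpole-Direct-v0", cfg=CartpoleEnvCfg())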