# Source code for omni.isaac.lab.envs.mdp.rewards

# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Common functions that can be used to enable reward functions.

The functions can be passed to the :class:`omni.isaac.lab.managers.RewardTermCfg` object to include
the reward introduced by the function.
"""

from __future__ import annotations

import torch
from typing import TYPE_CHECKING

from omni.isaac.lab.assets import Articulation, RigidObject
from omni.isaac.lab.managers import SceneEntityCfg
from omni.isaac.lab.managers.manager_base import ManagerTermBase
from omni.isaac.lab.managers.manager_term_cfg import RewardTermCfg
from omni.isaac.lab.sensors import ContactSensor, RayCaster

if TYPE_CHECKING:
    from omni.isaac.lab.envs import ManagerBasedRLEnv

"""
General.
"""


def is_alive(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Reward for being alive.

    Args:
        env: The environment whose termination manager is queried.

    Returns:
        A float tensor that is 1.0 where the termination manager's ``terminated``
        flag is not set and 0.0 where it is set.
    """
    # invert the boolean termination flags and cast them to float
    return (~env.termination_manager.terminated).float()
def is_terminated(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Penalize terminated episodes that don't correspond to episodic timeouts.

    Args:
        env: The environment whose termination manager is queried.

    Returns:
        The termination manager's ``terminated`` flags cast to float
        (1.0 where set, 0.0 otherwise).
    """
    return env.termination_manager.terminated.float()
class is_terminated_term(ManagerTermBase):
    """Penalize termination for specific terms that don't correspond to episodic timeouts.

    The parameters are as follows:

    * :attr:`term_keys`: The termination terms to penalize. This can be a string, a list of strings
      or regular expressions. Default is ".*" which penalizes all terminations.

    The reward is computed as the sum of the termination terms that are not episodic timeouts.
    This means that the reward is 0 if the episode is terminated due to an episodic timeout. Otherwise,
    if two termination terms are active, the reward is 2.
    """

    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        """Resolve the termination terms to penalize.

        Args:
            cfg: The reward term configuration. ``cfg.params["term_keys"]`` is read if present,
                otherwise ``".*"`` is used.
            env: The environment whose termination manager is searched for matching terms.
        """
        # initialize the base class
        super().__init__(cfg, env)
        # find and store the termination terms
        term_keys = cfg.params.get("term_keys", ".*")
        self._term_names = env.termination_manager.find_terms(term_keys)

    def __call__(self, env: ManagerBasedRLEnv, term_keys: str | list[str] = ".*") -> torch.Tensor:
        """Compute the unweighted reward for the stored termination terms.

        Args:
            env: The environment to evaluate.
            term_keys: Not used here; the terms were resolved once in :meth:`__init__`.
                NOTE(review): presumably kept in the signature so that the configured
                parameters can be forwarded to the call -- confirm against the reward manager.

        Returns:
            For each environment, the sum of the stored termination term values, zeroed
            wherever the termination manager's ``time_outs`` flag is set.
        """
        # Return the unweighted reward for the termination terms
        reset_buf = torch.zeros(env.num_envs, device=env.device)
        for term in self._term_names:
            # Sums over terminations term values to account for multiple terminations in the same step
            reset_buf += env.termination_manager.get_term(term)

        # mask out the environments that are flagged as timed out
        return (reset_buf * (~env.termination_manager.time_outs)).float()
""" Root penalties. """
def lin_vel_z_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize z-axis base linear velocity using L2 squared kernel.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read the velocity from. Defaults to ``"robot"``.

    Returns:
        The square of component 2 of ``root_com_lin_vel_b`` for each environment.
    """
    # extract the used quantities (to enable type-hinting)
    asset: RigidObject = env.scene[asset_cfg.name]
    return torch.square(asset.data.root_com_lin_vel_b[:, 2])
def ang_vel_xy_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize xy-axis base angular velocity using L2 squared kernel.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read the velocity from. Defaults to ``"robot"``.

    Returns:
        The sum of squares of the first two components of ``root_com_ang_vel_b``
        for each environment.
    """
    # extract the used quantities (to enable type-hinting)
    asset: RigidObject = env.scene[asset_cfg.name]
    return torch.sum(torch.square(asset.data.root_com_ang_vel_b[:, :2]), dim=1)
def flat_orientation_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize non-flat base orientation using L2 squared kernel.

    This is computed by penalizing the xy-components of the projected gravity vector.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read the projected gravity from. Defaults to ``"robot"``.

    Returns:
        The sum of squares of the first two components of ``projected_gravity_b``
        for each environment.
    """
    # extract the used quantities (to enable type-hinting)
    asset: RigidObject = env.scene[asset_cfg.name]
    return torch.sum(torch.square(asset.data.projected_gravity_b[:, :2]), dim=1)
def base_height_l2(
    env: ManagerBasedRLEnv,
    target_height: float,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    sensor_cfg: SceneEntityCfg | None = None,
) -> torch.Tensor:
    """Penalize asset height from its target using L2 squared kernel.

    Note:
        For flat terrain, target height is in the world frame. For rough terrain,
        sensor readings can adjust the target height to account for the terrain.

    Args:
        env: The environment holding the scene.
        target_height: The desired height of the asset root. Without a sensor it is compared
            directly against the world-frame z of the root link; with a sensor it is treated
            as an offset above the sensed terrain height.
        asset_cfg: The scene entity whose root height is penalized. Defaults to ``"robot"``.
        sensor_cfg: Optional ray-caster sensor used to estimate the terrain height.
            Defaults to None, in which case ``target_height`` is used as is.

    Returns:
        The squared difference between the root link z-position and the (adjusted) target height
        for each environment.
    """
    # extract the used quantities (to enable type-hinting)
    asset: RigidObject = env.scene[asset_cfg.name]
    if sensor_cfg is not None:
        sensor: RayCaster = env.scene[sensor_cfg.name]
        # Adjust the target height using the sensor data.
        # The terrain height is estimated as the mean z-coordinate of the ray hit points. The sensor's
        # own position (``pos_w``) must not be used here: it is the height of the sensor frame, not of
        # the terrain, so it would not measure the clearance above the ground.
        # NOTE(review): rays that miss the terrain may report non-finite hits -- confirm that the
        # sensor pattern always intersects the ground for the intended use.
        adjusted_target_height = target_height + torch.mean(sensor.data.ray_hits_w[..., 2], dim=1)
    else:
        # Use the provided target height directly for flat terrain
        adjusted_target_height = target_height
    # Compute the L2 squared penalty
    return torch.square(asset.data.root_link_pos_w[:, 2] - adjusted_target_height)
def body_lin_acc_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize the linear acceleration of bodies using L2-kernel.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read from. Only the bodies selected by
            ``asset_cfg.body_ids`` contribute. Defaults to ``"robot"``.

    Returns:
        For each environment, the sum over the selected bodies of the Euclidean norm of
        ``body_lin_acc_w``.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    # per-body acceleration magnitude (2-norm over the last axis); ``torch.norm`` is deprecated
    acc_norm = torch.linalg.vector_norm(asset.data.body_lin_acc_w[:, asset_cfg.body_ids, :], dim=-1)
    return torch.sum(acc_norm, dim=1)
""" Joint penalties. """
def joint_torques_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint torques applied on the articulation using L2 squared kernel.

    NOTE: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their joint torques
    contribute to the term.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read from. Defaults to ``"robot"``.

    Returns:
        For each environment, the sum of squared ``applied_torque`` over the selected joints.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    return torch.sum(torch.square(asset.data.applied_torque[:, asset_cfg.joint_ids]), dim=1)
def joint_vel_l1(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
    """Penalize joint velocities on the articulation using an L1-kernel.

    NOTE: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their joint velocities
    contribute to the term.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read from. This argument has no default.

    Returns:
        For each environment, the sum of absolute ``joint_vel`` over the selected joints.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    return torch.sum(torch.abs(asset.data.joint_vel[:, asset_cfg.joint_ids]), dim=1)
def joint_vel_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint velocities on the articulation using L2 squared kernel.

    NOTE: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their joint velocities
    contribute to the term.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read from. Defaults to ``"robot"``.

    Returns:
        For each environment, the sum of squared ``joint_vel`` over the selected joints.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    return torch.sum(torch.square(asset.data.joint_vel[:, asset_cfg.joint_ids]), dim=1)
def joint_acc_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint accelerations on the articulation using L2 squared kernel.

    NOTE: Only the joints configured in :attr:`asset_cfg.joint_ids` will have their joint accelerations
    contribute to the term.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read from. Defaults to ``"robot"``.

    Returns:
        For each environment, the sum of squared ``joint_acc`` over the selected joints.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    return torch.sum(torch.square(asset.data.joint_acc[:, asset_cfg.joint_ids]), dim=1)
def joint_deviation_l1(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint positions that deviate from the default one.

    NOTE: Only the joints configured in :attr:`asset_cfg.joint_ids` contribute to the term.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read from. Defaults to ``"robot"``.

    Returns:
        For each environment, the sum over the selected joints of the absolute difference
        between ``joint_pos`` and ``default_joint_pos``.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    # compute the deviation of the joint positions from their defaults
    angle = asset.data.joint_pos[:, asset_cfg.joint_ids] - asset.data.default_joint_pos[:, asset_cfg.joint_ids]
    return torch.sum(torch.abs(angle), dim=1)
def joint_pos_limits(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint positions if they cross the soft limits.

    This is computed as a sum of the absolute value of the difference between the joint position
    and the soft limits.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read from. Only the joints in ``asset_cfg.joint_ids``
            contribute. Defaults to ``"robot"``.

    Returns:
        For each environment, the summed amount by which the selected joints lie below index 0
        or above index 1 of ``soft_joint_pos_limits``.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    # compute out of limits constraints
    # -- amount below the lower limit (index 0): negative differences are kept and sign-flipped
    out_of_limits = -(
        asset.data.joint_pos[:, asset_cfg.joint_ids] - asset.data.soft_joint_pos_limits[:, asset_cfg.joint_ids, 0]
    ).clip(max=0.0)
    # -- amount above the upper limit (index 1): positive differences are kept
    out_of_limits += (
        asset.data.joint_pos[:, asset_cfg.joint_ids] - asset.data.soft_joint_pos_limits[:, asset_cfg.joint_ids, 1]
    ).clip(min=0.0)
    return torch.sum(out_of_limits, dim=1)
def joint_vel_limits(
    env: ManagerBasedRLEnv, soft_ratio: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
    """Penalize joint velocities if they cross the soft limits.

    This is computed as a sum of the absolute value of the difference between the joint velocity
    and the soft limits.

    Args:
        env: The environment holding the scene.
        soft_ratio: The ratio of the soft limits to be used.
        asset_cfg: The scene entity to read from. Only the joints in ``asset_cfg.joint_ids``
            contribute. Defaults to ``"robot"``.

    Returns:
        For each environment, the sum over the selected joints of the velocity excess beyond
        ``soft_joint_vel_limits * soft_ratio``, with each joint's excess clipped to ``[0, 1]``.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    # compute out of limits constraints
    out_of_limits = (
        torch.abs(asset.data.joint_vel[:, asset_cfg.joint_ids])
        - asset.data.soft_joint_vel_limits[:, asset_cfg.joint_ids] * soft_ratio
    )
    # clip to max error = 1 rad/s per joint to avoid huge penalties
    # (in-place is safe: the tensor was freshly created by the subtraction above)
    out_of_limits = out_of_limits.clip_(min=0.0, max=1.0)
    return torch.sum(out_of_limits, dim=1)
""" Action penalties. """
def applied_torque_limits(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize applied torques if they cross the limits.

    This is computed as a sum of the absolute value of the difference between the applied torques
    and the limits.

    .. caution::
        Currently, this only works for explicit actuators since we manually compute the applied torques.
        For implicit actuators, we currently cannot retrieve the applied torques from the physics engine.

    Args:
        env: The environment holding the scene.
        asset_cfg: The scene entity to read from. Only the joints in ``asset_cfg.joint_ids``
            contribute. Defaults to ``"robot"``.

    Returns:
        For each environment, the sum over the selected joints of
        ``|applied_torque - computed_torque|``.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    # compute out of limits constraints
    # TODO: We need to fix this to support implicit joints.
    out_of_limits = torch.abs(
        asset.data.applied_torque[:, asset_cfg.joint_ids] - asset.data.computed_torque[:, asset_cfg.joint_ids]
    )
    return torch.sum(out_of_limits, dim=1)
def action_rate_l2(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Penalize the rate of change of the actions using L2 squared kernel.

    Args:
        env: The environment whose action manager is queried.

    Returns:
        For each environment, the sum of squared differences between the action manager's
        ``action`` and ``prev_action``.
    """
    return torch.sum(torch.square(env.action_manager.action - env.action_manager.prev_action), dim=1)
def action_l2(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Penalize the actions using L2 squared kernel.

    Args:
        env: The environment whose action manager is queried.

    Returns:
        For each environment, the sum of the squared entries of the action manager's ``action``.
    """
    return torch.sum(torch.square(env.action_manager.action), dim=1)
""" Contact sensor. """
def undesired_contacts(env: ManagerBasedRLEnv, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
    """Penalize undesired contacts as the number of violations that are above a threshold.

    Args:
        env: The environment holding the scene.
        threshold: The force magnitude above which a body counts as being in contact.
        sensor_cfg: The contact sensor to read from. Only the bodies in ``sensor_cfg.body_ids``
            are considered.

    Returns:
        For each environment, the number of selected bodies whose largest contact-force
        magnitude in ``net_forces_w_history`` exceeds ``threshold``.
    """
    # extract the used quantities (to enable type-hinting)
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    # check if contact force is above threshold
    net_contact_forces = contact_sensor.data.net_forces_w_history
    # force magnitude per body (2-norm over the last axis); ``torch.norm`` is deprecated
    force_norm = torch.linalg.vector_norm(net_contact_forces[:, :, sensor_cfg.body_ids], dim=-1)
    # take the largest magnitude along dim 1 (presumably the history axis)
    is_contact = torch.max(force_norm, dim=1)[0] > threshold
    # sum over contacts for each environment
    return torch.sum(is_contact, dim=1)
def contact_forces(env: ManagerBasedRLEnv, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
    """Penalize contact forces as the amount of violations of the net contact force.

    Args:
        env: The environment holding the scene.
        threshold: The force magnitude above which the excess is penalized.
        sensor_cfg: The contact sensor to read from. Only the bodies in ``sensor_cfg.body_ids``
            are considered.

    Returns:
        For each environment, the sum over the selected bodies of the amount by which the
        largest contact-force magnitude in ``net_forces_w_history`` exceeds ``threshold``.
    """
    # extract the used quantities (to enable type-hinting)
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    net_contact_forces = contact_sensor.data.net_forces_w_history
    # force magnitude per body (2-norm over the last axis); ``torch.norm`` is deprecated
    force_norm = torch.linalg.vector_norm(net_contact_forces[:, :, sensor_cfg.body_ids], dim=-1)
    # compute the violation from the largest magnitude along dim 1 (presumably the history axis)
    violation = torch.max(force_norm, dim=1)[0] - threshold
    # compute the penalty: only magnitudes above the threshold contribute
    return torch.sum(violation.clip(min=0.0), dim=1)
""" Velocity-tracking rewards. """
def track_lin_vel_xy_exp(
    env: ManagerBasedRLEnv, std: float, command_name: str, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
    """Reward tracking of linear velocity commands (xy axes) using exponential kernel.

    Args:
        env: The environment holding the scene and the command manager.
        std: The width of the exponential kernel; the squared error is divided by ``std**2``.
        command_name: The name of the command term queried from the command manager.
        asset_cfg: The scene entity to read the velocity from. Defaults to ``"robot"``.

    Returns:
        ``exp(-error / std**2)`` for each environment, where ``error`` is the summed squared
        difference between the first two command components and the first two components
        of ``root_com_lin_vel_b``.
    """
    # extract the used quantities (to enable type-hinting)
    asset: RigidObject = env.scene[asset_cfg.name]
    # compute the error
    lin_vel_error = torch.sum(
        torch.square(env.command_manager.get_command(command_name)[:, :2] - asset.data.root_com_lin_vel_b[:, :2]),
        dim=1,
    )
    return torch.exp(-lin_vel_error / std**2)
def track_ang_vel_z_exp(
    env: ManagerBasedRLEnv, std: float, command_name: str, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
    """Reward tracking of angular velocity commands (yaw) using exponential kernel.

    Args:
        env: The environment holding the scene and the command manager.
        std: The width of the exponential kernel; the squared error is divided by ``std**2``.
        command_name: The name of the command term queried from the command manager.
        asset_cfg: The scene entity to read the velocity from. Defaults to ``"robot"``.

    Returns:
        ``exp(-error / std**2)`` for each environment, where ``error`` is the squared difference
        between component 2 of the command and component 2 of ``root_com_ang_vel_b``.
    """
    # extract the used quantities (to enable type-hinting)
    asset: RigidObject = env.scene[asset_cfg.name]
    # compute the error
    ang_vel_error = torch.square(
        env.command_manager.get_command(command_name)[:, 2] - asset.data.root_com_ang_vel_b[:, 2]
    )
    return torch.exp(-ang_vel_error / std**2)