# Source code for isaaclab.envs.mdp.rewards

# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Common functions that can be used to enable reward functions.

The functions can be passed to the :class:`isaaclab.managers.RewardTermCfg` object to include
the reward introduced by the function.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from isaaclab.managers import SceneEntityCfg
from isaaclab.managers.manager_base import ManagerTermBase
from isaaclab.managers.manager_term_cfg import RewardTermCfg
from isaaclab.utils.math import combine_frame_transforms, quat_error_magnitude, quat_mul

if TYPE_CHECKING:
    from isaaclab.assets import Articulation, RigidObject
    from isaaclab.envs import ManagerBasedRLEnv
    from isaaclab.sensors import ContactSensor, RayCaster

"""
General.
"""


def is_alive(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Reward for being alive.

    Args:
        env: The environment instance.

    Returns:
        A float tensor holding ``1.0`` wherever the termination manager's ``terminated`` flag is
        not set and ``0.0`` wherever it is set.
    """
    terminated_flags = env.termination_manager.terminated
    alive_flags = ~terminated_flags
    return alive_flags.float()
def is_terminated(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Penalize terminated episodes that don't correspond to episodic timeouts.

    Args:
        env: The environment instance.

    Returns:
        The termination manager's ``terminated`` flags converted to a float tensor.
    """
    terminated_flags = env.termination_manager.terminated
    return terminated_flags.float()
class is_terminated_term(ManagerTermBase):
    """Penalize termination for specific terms that don't correspond to episodic timeouts.

    The parameters are as follows:

    * :attr:`term_keys`: The termination terms to penalize. This can be a string, a list of strings
      or regular expressions. Default is ".*" which penalizes all terminations.

    The reward is computed as the sum of the termination terms that are not episodic timeouts.
    This means that the reward is 0 if the episode is terminated due to an episodic timeout. Otherwise,
    if two termination terms are active, the reward is 2.
    """

    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        """Resolve the termination terms to penalize.

        Args:
            cfg: The reward term configuration. ``cfg.params["term_keys"]`` (default ``".*"``)
                selects the termination terms.
            env: The environment instance.
        """
        # initialize the base class
        super().__init__(cfg, env)
        # find and store the termination terms once; they are not re-resolved per call
        term_keys = cfg.params.get("term_keys", ".*")
        self._term_names = env.termination_manager.find_terms(term_keys)

    def __call__(self, env: ManagerBasedRLEnv, term_keys: str | list[str] = ".*") -> torch.Tensor:
        """Compute the unweighted termination penalty.

        Args:
            env: The environment instance.
            term_keys: Accepted for signature compatibility but ignored here; the terms used are
                the ones resolved from ``cfg.params`` in :meth:`__init__`.

        Returns:
            A float tensor of length ``env.num_envs``: the number of selected termination terms
            that are active, zeroed wherever ``time_outs`` is set.
        """
        reset_buf = torch.zeros(env.num_envs, device=env.device)
        for term in self._term_names:
            # sum over termination terms to account for multiple terminations in the same step
            reset_buf += env.termination_manager.get_term(term)
        # mask out environments that ended because of an episodic timeout
        return (reset_buf * (~env.termination_manager.time_outs)).float()
"""
Root penalties.
"""


def lin_vel_z_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize the z-axis base linear velocity with a squared (L2) kernel.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity whose root velocity is used.

    Returns:
        The square of the z-component of the asset's body-frame root linear velocity.
    """
    asset: RigidObject = env.scene[asset_cfg.name]
    vertical_velocity = asset.data.root_lin_vel_b.torch[:, 2]
    return vertical_velocity.square()
def ang_vel_xy_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize the x- and y-components of the base angular velocity with a squared L2 kernel.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity whose root angular velocity is used.

    Returns:
        Sum of the squared x/y components of the asset's body-frame root angular velocity.
    """
    asset: RigidObject = env.scene[asset_cfg.name]
    planar_ang_vel = asset.data.root_ang_vel_b.torch[:, :2]
    return planar_ang_vel.square().sum(dim=1)
def flat_orientation_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize a non-flat base orientation with a squared L2 kernel.

    The gravity vector projected into the base frame has zero x/y components when the base is
    level, so the sum of their squares measures how far the base is tilted.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity whose projected gravity is used.

    Returns:
        Sum of the squared x/y components of ``projected_gravity_b``.
    """
    asset: RigidObject = env.scene[asset_cfg.name]
    gravity_xy = asset.data.projected_gravity_b.torch[:, :2]
    return gravity_xy.square().sum(dim=1)
def base_height_l2(
    env: ManagerBasedRLEnv,
    target_height: float,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    sensor_cfg: SceneEntityCfg | None = None,
) -> torch.Tensor:
    """Penalize the asset's root height deviating from a target, using a squared L2 kernel.

    Note:
        Without a sensor, ``target_height`` is taken as a world-frame height (flat terrain).
        With a ray-caster sensor, the target is offset by the mean z-coordinate of the ray hits,
        so the height is effectively measured relative to the sensed terrain.

    Args:
        env: The environment instance.
        target_height: Desired root height.
        asset_cfg: Scene entity whose root position is used.
        sensor_cfg: Optional ray-caster sensor used to adjust the target for rough terrain.

    Returns:
        The squared difference between the root z-position and the (adjusted) target height.
    """
    asset: RigidObject = env.scene[asset_cfg.name]
    if sensor_cfg is None:
        # flat terrain: use the provided world-frame target as-is
        reference_height = target_height
    else:
        sensor: RayCaster = env.scene[sensor_cfg.name]
        # rough terrain: shift the target by the average height of the ray hits
        reference_height = target_height + sensor.data.ray_hits_w.torch[..., 2].mean(dim=1)
    height_error = asset.data.root_pos_w.torch[:, 2] - reference_height
    return torch.square(height_error)
def body_lin_acc_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize the linear acceleration of bodies using an L2 kernel.

    For each body selected by ``asset_cfg.body_ids`` the Euclidean norm (not its square) of the
    world-frame linear acceleration is computed; the norms are then summed over the bodies.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity and body selection.

    Returns:
        The sum of per-body linear-acceleration norms.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    selected_acc = asset.data.body_lin_acc_w.torch[:, asset_cfg.body_ids, :]
    per_body_norm = torch.linalg.norm(selected_acc, dim=-1)
    return per_body_norm.sum(dim=1)
"""
Joint penalties.
"""


def joint_torques_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize the torques applied to the articulation's joints with a squared L2 kernel.

    .. note::
        Only the joints selected by :attr:`asset_cfg.joint_ids` contribute to the term.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity and joint selection.

    Returns:
        The sum over the selected joints of the squared applied torque.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    selected_torques = asset.data.applied_torque.torch[:, asset_cfg.joint_ids]
    return selected_torques.square().sum(dim=1)
def joint_vel_l1(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint velocities on the articulation using an L1-kernel.

    .. note::
        Only the joints configured in :attr:`asset_cfg.joint_ids` will have their joint velocities
        contribute to the term.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity and joint selection. Defaults to the ``"robot"`` entity, matching
            the other joint penalty terms in this module.

    Returns:
        The sum over the selected joints of the absolute joint velocity.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    return torch.sum(torch.abs(asset.data.joint_vel.torch[:, asset_cfg.joint_ids]), dim=1)
def joint_vel_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint velocities on the articulation with a squared L2 kernel.

    .. note::
        Only the joints selected by :attr:`asset_cfg.joint_ids` contribute to the term.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity and joint selection.

    Returns:
        The sum over the selected joints of the squared joint velocity.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    selected_vel = asset.data.joint_vel.torch[:, asset_cfg.joint_ids]
    return selected_vel.square().sum(dim=1)
def joint_acc_l2(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint accelerations on the articulation with a squared L2 kernel.

    .. note::
        Only the joints selected by :attr:`asset_cfg.joint_ids` contribute to the term.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity and joint selection.

    Returns:
        The sum over the selected joints of the squared joint acceleration.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    selected_acc = asset.data.joint_acc.torch[:, asset_cfg.joint_ids]
    return selected_acc.square().sum(dim=1)
def joint_deviation_l1(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint positions that deviate from the default one using an L1 kernel.

    .. note::
        Only the joints configured in :attr:`asset_cfg.joint_ids` contribute to the term.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity and joint selection.

    Returns:
        The sum over the selected joints of ``|joint_pos - default_joint_pos|``.
    """
    # extract the used quantities (to enable type-hinting)
    asset: Articulation = env.scene[asset_cfg.name]
    # compute each selected joint's offset from its default position
    deviation = (
        asset.data.joint_pos.torch[:, asset_cfg.joint_ids] - asset.data.default_joint_pos.torch[:, asset_cfg.joint_ids]
    )
    return torch.sum(torch.abs(deviation), dim=1)
def joint_pos_limits(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize joint positions that cross the soft position limits.

    For each selected joint, the distance below the lower soft limit and the distance above the
    upper soft limit (each zero when the limit is respected) are added up, then summed over joints.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity and joint selection.

    Returns:
        The total soft-limit violation over the selected joints.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    joint_pos = asset.data.joint_pos.torch[:, asset_cfg.joint_ids]
    soft_limits = asset.data.soft_joint_pos_limits.torch[:, asset_cfg.joint_ids]
    # how far the joint sits below its lower soft limit (0 when above it)
    below_lower = (soft_limits[..., 0] - joint_pos).clip(min=0.0)
    # how far the joint sits above its upper soft limit (0 when below it)
    above_upper = (joint_pos - soft_limits[..., 1]).clip(min=0.0)
    return torch.sum(below_lower + above_upper, dim=1)
def joint_vel_limits(
    env: ManagerBasedRLEnv, soft_ratio: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
    """Penalize joint velocities that exceed a scaled version of the soft velocity limits.

    For each selected joint, the amount by which ``|joint_vel|`` exceeds
    ``soft_joint_vel_limits * soft_ratio`` is taken, clipped to the range ``[0, 1]`` so a single
    joint cannot produce a huge penalty, and summed over the joints.

    Args:
        env: The environment instance.
        soft_ratio: The ratio of the soft limits to be used.
        asset_cfg: Scene entity and joint selection.

    Returns:
        The sum of the clipped per-joint velocity excess.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    speed = torch.abs(asset.data.joint_vel.torch[:, asset_cfg.joint_ids])
    speed_limit = asset.data.soft_joint_vel_limits.torch[:, asset_cfg.joint_ids] * soft_ratio
    # penalize only the excess, capped at 1.0 per joint
    excess = (speed - speed_limit).clamp(min=0.0, max=1.0)
    return excess.sum(dim=1)
"""
Action penalties.
"""


def applied_torque_limits(env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize applied torques that differ from the computed torques.

    For each selected joint, the absolute difference between the applied torque and the computed
    torque is taken and summed over the joints. A non-zero difference presumably indicates that the
    actuator limited the computed torque (verify against the actuator model).

    .. caution::
        Currently, this only works for explicit actuators since we manually compute the applied torques.
        For implicit actuators, we currently cannot retrieve the applied torques from the physics engine.

    Args:
        env: The environment instance.
        asset_cfg: Scene entity and joint selection.

    Returns:
        The sum over the selected joints of ``|applied_torque - computed_torque|``.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    # TODO: We need to fix this to support implicit joints.
    applied = asset.data.applied_torque.torch[:, asset_cfg.joint_ids]
    computed = asset.data.computed_torque.torch[:, asset_cfg.joint_ids]
    return (applied - computed).abs().sum(dim=1)
def action_rate_l2(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Penalize the rate of change of the actions with a squared L2 kernel.

    Args:
        env: The environment instance.

    Returns:
        The sum over action dimensions of ``(action - prev_action) ** 2``.
    """
    manager = env.action_manager
    action_delta = manager.action - manager.prev_action
    return action_delta.square().sum(dim=1)
def action_l2(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Penalize the actions with a squared L2 kernel.

    Args:
        env: The environment instance.

    Returns:
        The sum over action dimensions of the squared current action.
    """
    current_action = env.action_manager.action
    return current_action.square().sum(dim=1)
"""
Contact sensor.
"""


def undesired_contacts(env: ManagerBasedRLEnv, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
    """Penalize undesired contacts by counting the bodies whose contact force exceeds a threshold.

    A body counts as in contact when the peak norm of its net contact force over the sensor's
    history window is strictly greater than ``threshold``.

    Args:
        env: The environment instance.
        threshold: Force-norm threshold above which a body counts as in contact.
        sensor_cfg: Contact sensor and body selection.

    Returns:
        The number of selected bodies in contact (integer tensor, one count per row).
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    force_history = contact_sensor.data.net_forces_w_history.torch[:, :, sensor_cfg.body_ids]
    # peak force magnitude over the history window, per body
    peak_force = torch.linalg.norm(force_history, dim=-1).max(dim=1).values
    in_contact = peak_force > threshold
    return in_contact.sum(dim=1)
def desired_contacts(env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg, threshold: float = 1.0) -> torch.Tensor:
    """Penalize if none of the desired contacts are present.

    A body counts as in contact when the peak norm of its net contact force over the sensor's
    history window is strictly greater than ``threshold``.

    Args:
        env: The environment instance.
        sensor_cfg: Contact sensor and selection of the bodies that are expected to be in contact.
        threshold: Force-norm threshold above which a body counts as in contact.

    Returns:
        A float tensor that is ``1.0`` where none of the selected bodies is in contact and ``0.0``
        where at least one of them is.
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    contacts = (
        contact_sensor.data.net_forces_w_history.torch[:, :, sensor_cfg.body_ids, :].norm(dim=-1).max(dim=1)[0]
        > threshold
    )
    # true only when every selected body is out of contact
    zero_contact = (~contacts).all(dim=1)
    return 1.0 * zero_contact
def contact_forces(env: ManagerBasedRLEnv, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
    """Penalize contact forces by the amount they exceed a threshold.

    For each selected body, the peak norm of its net contact force over the sensor's history window
    is compared to ``threshold``; only the positive excess is kept and summed over the bodies.

    Args:
        env: The environment instance.
        threshold: Force-norm threshold below which no penalty is applied.
        sensor_cfg: Contact sensor and body selection.

    Returns:
        The summed force excess over the selected bodies.
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    force_history = contact_sensor.data.net_forces_w_history.torch[:, :, sensor_cfg.body_ids]
    peak_force = torch.linalg.norm(force_history, dim=-1).max(dim=1).values
    # only the portion of the force above the threshold is penalized
    excess = (peak_force - threshold).clip(min=0.0)
    return excess.sum(dim=1)
"""
Velocity-tracking rewards.
"""


def track_lin_vel_xy_exp(
    env: ManagerBasedRLEnv, std: float, command_name: str, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
    """Reward tracking of linear velocity commands (xy axes) with an exponential kernel.

    The first two command columns are compared with the x/y components of the body-frame root
    linear velocity; the reward is ``exp(-||error||^2 / std^2)``.

    Args:
        env: The environment instance.
        std: Kernel width.
        command_name: Name of the velocity command.
        asset_cfg: Scene entity whose root velocity is used.

    Returns:
        The tracking reward, equal to 1 for perfect tracking.
    """
    asset: RigidObject = env.scene[asset_cfg.name]
    commanded_xy = env.command_manager.get_command(command_name)[:, :2]
    measured_xy = asset.data.root_lin_vel_b.torch[:, :2]
    squared_error = (commanded_xy - measured_xy).square().sum(dim=1)
    return torch.exp(-squared_error / std**2)
def track_ang_vel_z_exp(
    env: ManagerBasedRLEnv, std: float, command_name: str, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
    """Reward tracking of angular velocity commands (yaw) with an exponential kernel.

    Command column 2 is compared with the z-component of the body-frame root angular velocity;
    the reward is ``exp(-error^2 / std^2)``.

    Args:
        env: The environment instance.
        std: Kernel width.
        command_name: Name of the velocity command.
        asset_cfg: Scene entity whose root angular velocity is used.

    Returns:
        The tracking reward, equal to 1 for perfect tracking.
    """
    asset: RigidObject = env.scene[asset_cfg.name]
    yaw_rate_cmd = env.command_manager.get_command(command_name)[:, 2]
    yaw_rate = asset.data.root_ang_vel_b.torch[:, 2]
    squared_error = (yaw_rate_cmd - yaw_rate).square()
    return torch.exp(-squared_error / std**2)
"""
Pose-tracking rewards.
"""


def position_command_error(env: ManagerBasedRLEnv, command_name: str, asset_cfg: SceneEntityCfg) -> torch.Tensor:
    """Penalize position tracking error using the L2 norm.

    The commanded position (first three command entries, expressed in the asset's root frame) is
    transformed into the world frame using the root pose, and compared with the world position of
    the first body in ``asset_cfg.body_ids``.

    Args:
        env: The environment instance.
        command_name: Name of the pose command.
        asset_cfg: Scene entity and body selection; only ``body_ids[0]`` is used.

    Returns:
        The Euclidean distance between the body position and the commanded position.
    """
    asset: RigidObject = env.scene[asset_cfg.name]
    command = env.command_manager.get_command(command_name)
    root_pos_w = asset.data.root_pos_w.torch
    root_quat_w = asset.data.root_quat_w.torch
    # express the root-frame target position in the world frame
    target_pos_w, _ = combine_frame_transforms(root_pos_w, root_quat_w, command[:, :3])
    body_pos_w = asset.data.body_pos_w.torch[:, asset_cfg.body_ids[0]]  # type: ignore
    position_error = body_pos_w - target_pos_w
    return torch.linalg.norm(position_error, dim=1)
def position_command_error_tanh(
    env: ManagerBasedRLEnv, std: float, command_name: str, asset_cfg: SceneEntityCfg
) -> torch.Tensor:
    """Reward tracking of the position error using the tanh kernel.

    The position error [m] is computed as in :func:`position_command_error` and mapped through
    ``1 - tanh(error / std)``, where ``std`` [m] sets the distance scale. Because the error is a
    non-negative norm, the reward lies in ``(0, 1]``: it equals 1 at zero error and decays toward 0
    as the error grows.

    Args:
        env: The environment instance.
        std: Distance scale of the tanh kernel [m].
        command_name: Name of the pose command.
        asset_cfg: Scene entity and body selection (see :func:`position_command_error`).

    Returns:
        The bounded position-tracking reward.
    """
    distance = position_command_error(env, command_name, asset_cfg)
    return 1 - torch.tanh(distance / std)
def orientation_command_error(env: ManagerBasedRLEnv, command_name: str, asset_cfg: SceneEntityCfg) -> torch.Tensor:
    """Penalize orientation tracking error using the shortest-path quaternion distance.

    The commanded orientation (command entries ``[3:7]``, expressed relative to the asset's root) is
    composed with the root orientation to obtain a world-frame target, and compared with the world
    orientation of the first body in ``asset_cfg.body_ids``.

    Args:
        env: The environment instance.
        command_name: Name of the pose command.
        asset_cfg: Scene entity and body selection; only ``body_ids[0]`` is used.

    Returns:
        The orientation error [rad] as returned by ``quat_error_magnitude``.
    """
    asset: RigidObject = env.scene[asset_cfg.name]
    command = env.command_manager.get_command(command_name)
    # rotate the root-relative target orientation into the world frame
    target_quat_w = quat_mul(asset.data.root_quat_w.torch, command[:, 3:7])
    body_quat_w = asset.data.body_quat_w.torch[:, asset_cfg.body_ids[0]]  # type: ignore
    return quat_error_magnitude(body_quat_w, target_quat_w)