Source code for isaaclab_contrib.controllers.lee_controller_base

# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Base class for Lee-style geometric controllers."""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import warp as wp

import isaaclab.sim as sim_utils
import isaaclab.utils.math as math_utils

from isaaclab_contrib.utils.math import aggregate_inertia_about_robot_com

if TYPE_CHECKING:
    from isaaclab.assets import Multirotor

    from .lee_controller_base_cfg import LeeControllerBaseCfg


class LeeControllerBase:
    """Common machinery shared by Lee-style geometric controllers.

    On construction the controller aggregates the multirotor's per-body masses
    and inertias about the robot center of mass (via
    ``aggregate_inertia_about_robot_com``), caches the simulation gravity
    vector, and allocates its output buffers. Concrete controllers implement
    :meth:`compute` and may override :meth:`_randomize_params`.
    """

    cfg: LeeControllerBaseCfg
    device: str
    robot: Multirotor

    def __init__(self, cfg: LeeControllerBaseCfg, asset: Multirotor, num_envs: int, device: str):
        """Store references, pre-compute aggregate inertial properties and allocate buffers.

        Args:
            cfg: Controller configuration.
            asset: Multirotor asset to control.
            num_envs: Number of environments.
            device: Device to run computations on.
        """
        self.cfg = cfg
        self.robot = asset
        self.device = device
        self.num_envs = num_envs

        data = self.robot.data
        quat_root_w = self._to_torch(data.root_link_quat_w)
        link_pos_w = self._to_torch(data.body_link_pos_w)
        pos_root_w = self._to_torch(data.root_pos_w)
        com_pos_b = self._to_torch(data.body_com_pos_b)
        com_quat_b = self._to_torch(data.body_com_quat_b)
        link_quat_w = self._to_torch(data.body_link_quat_w)

        # Broadcast the root orientation to every body so that each link pose
        # can be re-expressed in the root frame before aggregation.
        quat_root_per_body = quat_root_w.unsqueeze(1).expand(num_envs, self.robot.num_bodies, 4)
        link_offset_w = link_pos_w - pos_root_w.unsqueeze(1)

        masses = self._to_torch(self.robot.root_view.get_masses())
        # Bodies with non-positive mass get a zero inverse mass instead of inf.
        inv_masses = torch.where(masses > 0, 1.0 / masses, torch.zeros_like(masses))
        inertias = self._to_torch(self.robot.root_view.get_inertias())

        link_pos_root = math_utils.quat_apply_inverse(quat_root_per_body, link_offset_w)
        link_quat_root = math_utils.quat_mul(math_utils.quat_inv(quat_root_per_body), link_quat_w)

        # The third return value of the helper is not used by this class.
        self.mass, self.robot_inertia, _ = aggregate_inertia_about_robot_com(
            inertias,
            inv_masses,
            com_pos_b,
            com_quat_b,
            link_pos_root,
            link_quat_root,
        )

        # Gravity comes from the active simulation context; the expanded tensor
        # is a shared (read-only) view repeated for every environment.
        gravity = sim_utils.SimulationContext.instance().cfg.gravity
        self.gravity = torch.tensor(gravity, device=device, dtype=torch.float32).expand(num_envs, -1)

        # Output wrench per environment laid out as [fx, fy, fz, tx, ty, tz].
        self.wrench_command_b = torch.zeros(num_envs, 6, device=device)
        self.rotation_matrix_buffer = torch.zeros(num_envs, 3, 3, device=device)

    def _to_torch(self, x):
        """Return ``x`` as a torch tensor on the controller device.

        Torch tensors are only moved; anything else is converted with
        ``wp.to_torch`` first.
        """
        tensor = x if torch.is_tensor(x) else wp.to_torch(x)
        return tensor.to(self.device)

    def _root_state_tensors(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Read the root orientation, body-frame angular velocity and world-frame linear velocity.

        Returns:
            Tuple ``(root_quat_w, root_ang_vel_b, root_lin_vel_w)`` on the controller device.
        """
        data = self.robot.data
        return (
            self._to_torch(data.root_quat_w),
            self._to_torch(data.root_ang_vel_b),
            self._to_torch(data.root_lin_vel_w),
        )

    def reset(self):
        """Reset the controller for every environment."""
        self.reset_idx(env_ids=None)

    def reset_idx(self, env_ids: torch.Tensor | None):
        """Reset controller state for a subset of environments.

        Args:
            env_ids: Indices of the environments to reset; ``None`` selects all of them.
        """
        self._randomize_params(slice(None) if env_ids is None else env_ids)

    def _randomize_params(self, env_ids: slice | torch.Tensor):
        """Hook for per-environment gain randomization; the base class does nothing.

        Subclasses override this to resample their parameters for ``env_ids``.
        """

    def compute(self, command: torch.Tensor) -> torch.Tensor:
        """Turn an input command into a body-frame wrench.

        Args:
            command: Controller-specific input command.

        Returns:
            (num_envs, 6) wrench command [fx, fy, fz, tx, ty, tz] in body frame.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("Subclasses must implement compute()")