Source code for omni.isaac.lab.actuators.actuator_net

# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Neural network models for actuators.

Currently, the following models are supported:

* Multi-Layer Perceptron (MLP)
* Long Short-Term Memory (LSTM)

"""

from __future__ import annotations

import torch
from collections.abc import Sequence
from typing import TYPE_CHECKING

from omni.isaac.lab.utils.assets import read_file
from omni.isaac.lab.utils.types import ArticulationActions

from .actuator_pd import DCMotor

if TYPE_CHECKING:
    from .actuator_cfg import ActuatorNetLSTMCfg, ActuatorNetMLPCfg


class ActuatorNetLSTM(DCMotor):
    """Learned actuator model backed by a recurrent (LSTM) network.

    In contrast to the MLP-based actuator network :cite:t:`hwangbo2019learning`, the learned
    model here is a temporal network (LSTM) following :cite:t:`rudin2022learning`. The hidden
    and cell states of the LSTM carry the temporal context, so no explicit history of past
    joint states has to be stored.

    Note:
        From the incoming control action, only the desired joint positions are used. The
        network is fed the joint position error and the current joint velocity.
    """

    cfg: ActuatorNetLSTMCfg
    """The configuration of the actuator model."""

    def __init__(self, cfg: ActuatorNetLSTMCfg, *args, **kwargs):
        """Load the TorchScript LSTM and allocate its input and recurrent-state buffers.

        Args:
            cfg: The configuration of the actuator model. ``cfg.network_file`` is the
                location of the TorchScript file that is loaded.
            *args: Positional arguments forwarded to the parent class.
            **kwargs: Keyword arguments forwarded to the parent class.
        """
        super().__init__(cfg, *args, **kwargs)

        # load the TorchScript network onto the actuator's device
        self.network = torch.jit.load(read_file(self.cfg.network_file), map_location=self._device)

        # infer the LSTM topology from its parameters
        # NOTE(review): dividing by four assumes every layer owns exactly four parameter
        #   tensors (input/hidden weights and biases), i.e. a unidirectional LSTM with biases.
        lstm_params = self.network.lstm.state_dict()
        num_layers = len(lstm_params) // 4
        hidden_dim = lstm_params["weight_hh_l0"].shape[1]

        # every (environment, joint) pair is treated as one independent batch element
        batch_size = self._num_envs * self.num_joints
        state_shape = (num_layers, batch_size, hidden_dim)

        # network input: sequence length 1, two features (position error, velocity)
        self.sea_input = torch.zeros(batch_size, 1, 2, device=self._device)
        # recurrent state of the network
        self.sea_hidden_state = torch.zeros(state_shape, device=self._device)
        self.sea_cell_state = torch.zeros(state_shape, device=self._device)

        # per-environment views onto the same memory, used to reset individual environments
        per_env_shape = (num_layers, self._num_envs, self.num_joints, hidden_dim)
        self.sea_hidden_state_per_env = self.sea_hidden_state.view(per_env_shape)
        self.sea_cell_state_per_env = self.sea_cell_state.view(per_env_shape)

    """
    Operations.
    """

    def reset(self, env_ids: Sequence[int]):
        """Zero the recurrent state of the given environments.

        Args:
            env_ids: Indices of the environments whose hidden and cell states are cleared.
        """
        # writes go through the per-environment views and thus land in the flat state buffers
        with torch.no_grad():
            for recurrent_state in (self.sea_hidden_state_per_env, self.sea_cell_state_per_env):
                recurrent_state[:, env_ids] = 0.0

    def compute(
        self, control_action: ArticulationActions, joint_pos: torch.Tensor, joint_vel: torch.Tensor
    ) -> ArticulationActions:
        """Run one LSTM step and convert the position command into clipped joint efforts.

        Args:
            control_action: The commanded actions. Only ``joint_positions`` is read.
            joint_pos: The current joint positions.
            joint_vel: The current joint velocities.

        Returns:
            The same ``control_action`` object with ``joint_efforts`` set to the clipped
            network output and ``joint_positions`` / ``joint_velocities`` set to None.
        """
        # fill the network input: feature 0 is the position error, feature 1 the velocity
        position_error = control_action.joint_positions - joint_pos
        self.sea_input[:, 0, 0] = position_error.flatten()
        self.sea_input[:, 0, 1] = joint_vel.flatten()
        # the parent's effort clipping depends on the current joint velocity
        self._joint_vel[:] = joint_vel

        # step the network and write the new recurrent state back in-place so that the
        # per-environment views keep pointing at valid memory
        with torch.inference_mode():
            torques, (next_hidden, next_cell) = self.network(
                self.sea_input, (self.sea_hidden_state, self.sea_cell_state)
            )
            self.sea_hidden_state[:] = next_hidden
            self.sea_cell_state[:] = next_cell

        # one torque per (environment, joint)
        self.computed_effort = torques.reshape(self._num_envs, self.num_joints)
        # clip the computed effort based on the motor limits
        self.applied_effort = self._clip_effort(self.computed_effort)

        # the actuator outputs efforts only; position and velocity targets are consumed here
        control_action.joint_efforts = self.applied_effort
        control_action.joint_positions = None
        control_action.joint_velocities = None
        return control_action
class ActuatorNetMLP(DCMotor):
    """Actuator model based on multi-layer perceptron and joint history.

    Many times the analytical model is not sufficient to capture the actuator dynamics, the
    delay in the actuator response, or the non-linearities in the actuator. In these cases,
    a neural network model can be used to approximate the actuator dynamics. This model is
    trained using data collected from the physical actuator and maps the joint state and the
    desired joint command to the produced torque by the actuator.

    This class implements the learned model as a neural network based on the work from
    :cite:t:`hwangbo2019learning`. The class stores the history of the joint position errors
    and velocities which are used to provide input to the neural network. The model is loaded
    as a TorchScript.

    Note:
        From the incoming control action, only the desired joint positions are used. The
        network is fed histories of the joint position error and the joint velocity.
    """

    cfg: ActuatorNetMLPCfg
    """The configuration of the actuator model."""

    def __init__(self, cfg: ActuatorNetMLPCfg, *args, **kwargs):
        """Load the TorchScript MLP and allocate the joint-state history buffers.

        Args:
            cfg: The configuration of the actuator model. ``cfg.network_file`` is the
                location of the TorchScript file and ``cfg.input_idx`` lists the history
                indices (0 being the most recent sample) that are fed to the network.
            *args: Positional arguments forwarded to the parent class.
            **kwargs: Keyword arguments forwarded to the parent class.
        """
        super().__init__(cfg, *args, **kwargs)

        # load the model from JIT file
        file_bytes = read_file(self.cfg.network_file)
        self.network = torch.jit.load(file_bytes, map_location=self._device)

        # create buffers for MLP history
        # the buffers must be deep enough to hold the oldest sample requested by the network
        history_length = max(self.cfg.input_idx) + 1
        history_shape = (self._num_envs, history_length, self.num_joints)
        self._joint_pos_error_history = torch.zeros(history_shape, device=self._device)
        self._joint_vel_history = torch.zeros(history_shape, device=self._device)

    """
    Operations.
    """

    def reset(self, env_ids: Sequence[int]):
        """Clear the joint-state history of the given environments.

        Args:
            env_ids: Indices of the environments whose histories are zeroed.
        """
        # reset the history for the specified environments
        self._joint_pos_error_history[env_ids] = 0.0
        self._joint_vel_history[env_ids] = 0.0

    def compute(
        self, control_action: ArticulationActions, joint_pos: torch.Tensor, joint_vel: torch.Tensor
    ) -> ArticulationActions:
        """Update the histories, evaluate the MLP and return clipped joint efforts.

        Args:
            control_action: The commanded actions. Only ``joint_positions`` is read.
            joint_pos: The current joint positions.
            joint_vel: The current joint velocities.

        Returns:
            The same ``control_action`` object with ``joint_efforts`` set to the clipped
            network output and ``joint_positions`` / ``joint_velocities`` set to None.

        Raises:
            ValueError: If ``cfg.input_order`` is neither ``"pos_vel"`` nor ``"vel_pos"``.
        """
        # move history queue by 1 and update top of history (index 0 is the newest sample)
        # -- positions
        self._joint_pos_error_history = self._joint_pos_error_history.roll(1, 1)
        self._joint_pos_error_history[:, 0] = control_action.joint_positions - joint_pos
        # -- velocity
        self._joint_vel_history = self._joint_vel_history.roll(1, 1)
        self._joint_vel_history[:, 0] = joint_vel
        # save current joint vel for dc-motor clipping
        self._joint_vel[:] = joint_vel

        # compute network inputs: pick the requested history samples and flatten so that every
        # (environment, joint) pair becomes one row of the network batch
        batch_size = self._num_envs * self.num_joints
        # -- positions
        pos_input = torch.stack([self._joint_pos_error_history[:, i] for i in self.cfg.input_idx], dim=2)
        pos_input = pos_input.reshape(batch_size, -1)
        # -- velocity
        vel_input = torch.stack([self._joint_vel_history[:, i] for i in self.cfg.input_idx], dim=2)
        vel_input = vel_input.reshape(batch_size, -1)
        # -- scale and concatenate inputs in the order the network was trained with
        if self.cfg.input_order == "pos_vel":
            network_input = torch.cat([pos_input * self.cfg.pos_scale, vel_input * self.cfg.vel_scale], dim=1)
        elif self.cfg.input_order == "vel_pos":
            network_input = torch.cat([vel_input * self.cfg.vel_scale, pos_input * self.cfg.pos_scale], dim=1)
        else:
            raise ValueError(
                f"Invalid input order for MLP actuator net: {self.cfg.input_order}. Must be 'pos_vel' or 'vel_pos'."
            )

        # run network inference
        # autograd is disabled (as in the LSTM actuator) since the network is only evaluated;
        # this avoids building a graph on every simulation step
        with torch.inference_mode():
            torques = self.network(network_input).reshape(self._num_envs, self.num_joints)
        self.computed_effort = torques * self.cfg.torque_scale

        # clip the computed effort based on the motor limits
        self.applied_effort = self._clip_effort(self.computed_effort)

        # return torques
        control_action.joint_efforts = self.applied_effort
        control_action.joint_positions = None
        control_action.joint_velocities = None
        return control_action