Source code for omni.isaac.lab.managers.event_manager

# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Event manager for orchestrating operations based on different simulation events."""

from __future__ import annotations

import torch
from collections.abc import Sequence
from prettytable import PrettyTable
from typing import TYPE_CHECKING

import omni.log

from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import EventTermCfg

if TYPE_CHECKING:
    from omni.isaac.lab.envs import ManagerBasedEnv


[docs]class EventManager(ManagerBase): """Manager for orchestrating operations based on different simulation events. The event manager applies operations to the environment based on different simulation events. For example, changing the masses of objects or their friction coefficients during initialization/ reset, or applying random pushes to the robot at a fixed interval of steps. The user can specify several modes of events to fine-tune the behavior based on when to apply the event. The event terms are parsed from a config class containing the manager's settings and each term's parameters. Each event term should instantiate the :class:`EventTermCfg` class. Event terms can be grouped by their mode. The mode is a user-defined string that specifies when the event term should be applied. This provides the user complete control over when event terms should be applied. For a typical training process, you may want to apply events in the following modes: - "startup": Event is applied once at the beginning of the training. - "reset": Event is applied at every reset. - "interval": Event is applied at pre-specified intervals of time. However, you can also define your own modes and use them in the training process as you see fit. For this you will need to add the triggering of that mode in the environment implementation as well. .. note:: The triggering of operations corresponding to the mode ``"interval"`` are the only mode that are directly handled by the manager itself. The other modes are handled by the environment implementation. """ _env: ManagerBasedEnv """The environment instance."""
[docs] def __init__(self, cfg: object, env: ManagerBasedEnv): """Initialize the event manager. Args: cfg: A configuration object or dictionary (``dict[str, EventTermCfg]``). env: An environment object. """ # create buffers to parse and store terms self._mode_term_names: dict[str, list[str]] = dict() self._mode_term_cfgs: dict[str, list[EventTermCfg]] = dict() self._mode_class_term_cfgs: dict[str, list[EventTermCfg]] = dict() # call the base class (this will parse the terms config) super().__init__(cfg, env)
def __str__(self) -> str: """Returns: A string representation for event manager.""" msg = f"<EventManager> contains {len(self._mode_term_names)} active terms.\n" # add info on each mode for mode in self._mode_term_names: # create table for term information table = PrettyTable() table.title = f"Active Event Terms in Mode: '{mode}'" # add table headers based on mode if mode == "interval": table.field_names = ["Index", "Name", "Interval time range (s)"] table.align["Name"] = "l" for index, (name, cfg) in enumerate(zip(self._mode_term_names[mode], self._mode_term_cfgs[mode])): table.add_row([index, name, cfg.interval_range_s]) else: table.field_names = ["Index", "Name"] table.align["Name"] = "l" for index, name in enumerate(self._mode_term_names[mode]): table.add_row([index, name]) # convert table to string msg += table.get_string() msg += "\n" return msg """ Properties. """ @property def active_terms(self) -> dict[str, list[str]]: """Name of active event terms. The keys are the modes of event and the values are the names of the event terms. """ return self._mode_term_names @property def available_modes(self) -> list[str]: """Modes of events.""" return list(self._mode_term_names.keys()) """ Operations. """
[docs] def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]: # call all terms that are classes for mode_cfg in self._mode_class_term_cfgs.values(): for term_cfg in mode_cfg: term_cfg.func.reset(env_ids=env_ids) # nothing to log here return {}
[docs] def apply( self, mode: str, env_ids: Sequence[int] | None = None, dt: float | None = None, global_env_step_count: int | None = None, ): """Calls each event term in the specified mode. This function iterates over all the event terms in the specified mode and calls the function corresponding to the term. The function is called with the environment instance and the environment indices to apply the event to. For the "interval" mode, the function is called when the time interval has passed. This requires specifying the time step of the environment. For the "reset" mode, the function is called when the mode is "reset" and the total number of environment steps that have happened since the last trigger of the function is equal to its configured parameter for the number of environment steps between resets. Args: mode: The mode of event. env_ids: The indices of the environments to apply the event to. Defaults to None, in which case the event is applied to all environments when applicable. dt: The time step of the environment. This is only used for the "interval" mode. Defaults to None to simplify the call for other modes. global_env_step_count: The total number of environment steps that have happened. This is only used for the "reset" mode. Defaults to None to simplify the call for other modes. Raises: ValueError: If the mode is ``"interval"`` and the time step is not provided. ValueError: If the mode is ``"interval"`` and the environment indices are provided. This is an undefined behavior as the environment indices are computed based on the time left for each environment. ValueError: If the mode is ``"reset"`` and the total number of environment steps that have happened is not provided. """ # check if mode is valid if mode not in self._mode_term_names: omni.log.warn(f"Event mode '{mode}' is not defined. Skipping event.") return # check if mode is interval and dt is not provided if mode == "interval" and dt is None: raise ValueError(f"Event mode '{mode}' requires the time-step of the environment.") if mode == "interval" and env_ids is not None: raise ValueError( f"Event mode '{mode}' does not require environment indices. This is an undefined behavior" " as the environment indices are computed based on the time left for each environment." ) # check if mode is reset and env step count is not provided if mode == "reset" and global_env_step_count is None: raise ValueError(f"Event mode '{mode}' requires the total number of environment steps to be provided.") # iterate over all the event terms for index, term_cfg in enumerate(self._mode_term_cfgs[mode]): if mode == "interval": # extract time left for this term time_left = self._interval_term_time_left[index] # update the time left for each environment time_left -= dt # check if the interval has passed and sample a new interval # note: we compare with a small value to handle floating point errors if term_cfg.is_global_time: if time_left < 1e-6: lower, upper = term_cfg.interval_range_s sampled_interval = torch.rand(1) * (upper - lower) + lower self._interval_term_time_left[index][:] = sampled_interval # call the event term (with None for env_ids) term_cfg.func(self._env, None, **term_cfg.params) else: valid_env_ids = (time_left < 1e-6).nonzero().flatten() if len(valid_env_ids) > 0: lower, upper = term_cfg.interval_range_s sampled_time = torch.rand(len(valid_env_ids), device=self.device) * (upper - lower) + lower self._interval_term_time_left[index][valid_env_ids] = sampled_time # call the event term term_cfg.func(self._env, valid_env_ids, **term_cfg.params) elif mode == "reset": # obtain the minimum step count between resets min_step_count = term_cfg.min_step_count_between_reset # resolve the environment indices if env_ids is None: env_ids = slice(None) # We bypass the trigger mechanism if min_step_count is zero, i.e. apply term on every reset call. # This should avoid the overhead of checking the trigger condition. if min_step_count == 0: self._reset_term_last_triggered_step_id[index][env_ids] = global_env_step_count self._reset_term_last_triggered_once[index][env_ids] = True # call the event term with the environment indices term_cfg.func(self._env, env_ids, **term_cfg.params) else: # extract last reset step for this term last_triggered_step = self._reset_term_last_triggered_step_id[index][env_ids] triggered_at_least_once = self._reset_term_last_triggered_once[index][env_ids] # compute the steps since last reset steps_since_triggered = global_env_step_count - last_triggered_step # check if the term can be applied after the minimum step count between triggers has passed valid_trigger = steps_since_triggered >= min_step_count # check if the term has not been triggered yet (in that case, we trigger it at least once) # this is usually only needed at the start of the environment valid_trigger |= (last_triggered_step == 0) & ~triggered_at_least_once # select the valid environment indices based on the trigger if env_ids == slice(None): valid_env_ids = valid_trigger.nonzero().flatten() else: valid_env_ids = env_ids[valid_trigger] # reset the last reset step for each environment to the current env step count if len(valid_env_ids) > 0: self._reset_term_last_triggered_once[index][valid_env_ids] = True self._reset_term_last_triggered_step_id[index][valid_env_ids] = global_env_step_count # call the event term term_cfg.func(self._env, valid_env_ids, **term_cfg.params) else: # call the event term term_cfg.func(self._env, env_ids, **term_cfg.params)
""" Operations - Term settings. """
[docs] def set_term_cfg(self, term_name: str, cfg: EventTermCfg): """Sets the configuration of the specified term into the manager. The method finds the term by name by searching through all the modes. It then updates the configuration of the term with the first matching name. Args: term_name: The name of the event term. cfg: The configuration for the event term. Raises: ValueError: If the term name is not found. """ term_found = False for mode, terms in self._mode_term_names.items(): if term_name in terms: self._mode_term_cfgs[mode][terms.index(term_name)] = cfg term_found = True break if not term_found: raise ValueError(f"Event term '{term_name}' not found.")
[docs] def get_term_cfg(self, term_name: str) -> EventTermCfg: """Gets the configuration for the specified term. The method finds the term by name by searching through all the modes. It then returns the configuration of the term with the first matching name. Args: term_name: The name of the event term. Returns: The configuration of the event term. Raises: ValueError: If the term name is not found. """ for mode, terms in self._mode_term_names.items(): if term_name in terms: return self._mode_term_cfgs[mode][terms.index(term_name)] raise ValueError(f"Event term '{term_name}' not found.")
""" Helper functions. """ def _prepare_terms(self): # buffer to store the time left for "interval" mode # if interval is global, then it is a single value, otherwise it is per environment self._interval_term_time_left: list[torch.Tensor] = list() # buffer to store the step count when the term was last triggered for each environment for "reset" mode self._reset_term_last_triggered_step_id: list[torch.Tensor] = list() self._reset_term_last_triggered_once: list[torch.Tensor] = list() # check if config is dict already if isinstance(self.cfg, dict): cfg_items = self.cfg.items() else: cfg_items = self.cfg.__dict__.items() # iterate over all the terms for term_name, term_cfg in cfg_items: # check for non config if term_cfg is None: continue # check for valid config type if not isinstance(term_cfg, EventTermCfg): raise TypeError( f"Configuration for the term '{term_name}' is not of type EventTermCfg." f" Received: '{type(term_cfg)}'." ) if term_cfg.mode != "reset" and term_cfg.min_step_count_between_reset != 0: omni.log.warn( f"Event term '{term_name}' has 'min_step_count_between_reset' set to a non-zero value" " but the mode is not 'reset'. Ignoring the 'min_step_count_between_reset' value." ) # resolve common parameters self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2) # check if mode is a new mode if term_cfg.mode not in self._mode_term_names: # add new mode self._mode_term_names[term_cfg.mode] = list() self._mode_term_cfgs[term_cfg.mode] = list() self._mode_class_term_cfgs[term_cfg.mode] = list() # add term name and parameters self._mode_term_names[term_cfg.mode].append(term_name) self._mode_term_cfgs[term_cfg.mode].append(term_cfg) # check if the term is a class if isinstance(term_cfg.func, ManagerTermBase): self._mode_class_term_cfgs[term_cfg.mode].append(term_cfg) # resolve the mode of the events # -- interval mode if term_cfg.mode == "interval": if term_cfg.interval_range_s is None: raise ValueError( f"Event term '{term_name}' has mode 'interval' but 'interval_range_s' is not specified." ) # sample the time left for global if term_cfg.is_global_time: lower, upper = term_cfg.interval_range_s time_left = torch.rand(1) * (upper - lower) + lower self._interval_term_time_left.append(time_left) else: # sample the time left for each environment lower, upper = term_cfg.interval_range_s time_left = torch.rand(self.num_envs, device=self.device) * (upper - lower) + lower self._interval_term_time_left.append(time_left) # -- reset mode elif term_cfg.mode == "reset": if term_cfg.min_step_count_between_reset < 0: raise ValueError( f"Event term '{term_name}' has mode 'reset' but 'min_step_count_between_reset' is" f" negative: {term_cfg.min_step_count_between_reset}. Please provide a non-negative value." ) # initialize the current step count for each environment to zero step_count = torch.zeros(self.num_envs, device=self.device, dtype=torch.int32) self._reset_term_last_triggered_step_id.append(step_count) # initialize the trigger flag for each environment to zero no_trigger = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) self._reset_term_last_triggered_once.append(no_trigger)