# Source code for omni.isaac.lab.managers.curriculum_manager
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Curriculum manager for updating environment quantities subject to a training curriculum."""
from __future__ import annotations
import torch
from collections.abc import Sequence
from prettytable import PrettyTable
from typing import TYPE_CHECKING
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import CurriculumTermCfg
if TYPE_CHECKING:
from omni.isaac.lab.envs import ManagerBasedRLEnv
class CurriculumManager(ManagerBase):
    """Manager to implement and execute specific curricula.

    The curriculum manager updates various quantities of the environment subject to a training curriculum by
    calling a list of terms. These help stabilize learning by progressively making the learning tasks harder
    as the agent improves.

    The curriculum terms are parsed from a config class containing the manager's settings and each term's
    parameters. Each curriculum term should instantiate the :class:`CurriculumTermCfg` class.
    """

    _env: ManagerBasedRLEnv
    """The environment instance."""

    def __init__(self, cfg: object, env: ManagerBasedRLEnv):
        """Initialize the manager.

        Args:
            cfg: The configuration object or dictionary (``dict[str, CurriculumTermCfg]``)
            env: An environment object.

        Raises:
            TypeError: If curriculum term is not of type :class:`CurriculumTermCfg`.
            ValueError: If curriculum term configuration does not satisfy its function signature.
        """
        # create buffers to parse and store terms
        # NOTE: these must exist before the base constructor runs, since it parses the terms config
        self._term_names: list[str] = []
        self._term_cfgs: list[CurriculumTermCfg] = []
        self._class_term_cfgs: list[CurriculumTermCfg] = []
        # call the base class constructor (this will parse the terms config)
        super().__init__(cfg, env)
        # prepare logging: the state of a term stays None until the first call to compute()
        self._curriculum_state = {term_name: None for term_name in self._term_names}

    def __str__(self) -> str:
        """Returns: A string representation for curriculum manager."""
        msg = f"<CurriculumManager> contains {len(self._term_names)} active terms.\n"

        # create table for term information
        table = PrettyTable()
        table.title = "Active Curriculum Terms"
        table.field_names = ["Index", "Name"]
        # set alignment of table columns
        table.align["Name"] = "l"
        # add info on each term
        for index, name in enumerate(self._term_names):
            table.add_row([index, name])
        # convert table to string
        msg += table.get_string()
        msg += "\n"

        return msg

    """
    Properties.
    """

    @property
    def active_terms(self) -> list[str]:
        """Name of active curriculum terms."""
        return self._term_names

    """
    Operations.
    """

    def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
        """Returns the current state of individual curriculum terms.

        Note:
            The logged information does not depend on the environment indices :attr:`env_ids`;
            the state of all the terms is logged. The indices are only forwarded to the
            ``reset`` method of class-based curriculum terms.

        Args:
            env_ids: The environment indices passed on to class-based terms. Defaults to None.

        Returns:
            Dictionary of curriculum terms and their states. Terms that have not produced a state
            yet (i.e. whose state is None) are omitted.
        """
        extras = {}
        for term_name, term_state in self._curriculum_state.items():
            if term_state is not None:
                # deal with dict
                if isinstance(term_state, dict):
                    # each key is a separate state to log
                    for key, value in term_state.items():
                        if isinstance(value, torch.Tensor):
                            value = value.item()
                        extras[f"Curriculum/{term_name}/{key}"] = value
                else:
                    # log directly if not a dict
                    if isinstance(term_state, torch.Tensor):
                        term_state = term_state.item()
                    extras[f"Curriculum/{term_name}"] = term_state
        # reset all the curriculum terms
        for term_cfg in self._class_term_cfgs:
            term_cfg.func.reset(env_ids=env_ids)
        # return logged information
        return extras

    def compute(self, env_ids: Sequence[int] | None = None):
        """Update the curriculum terms.

        This function calls each curriculum term managed by the class and stores the value
        returned by the term as its current state.

        Args:
            env_ids: The list of environment IDs to update.
                If None, all the environments are updated. Defaults to None.
        """
        # resolve environment indices
        if env_ids is None:
            env_ids = slice(None)
        # iterate over all the curriculum terms
        for name, term_cfg in zip(self._term_names, self._term_cfgs):
            state = term_cfg.func(self._env, env_ids, **term_cfg.params)
            self._curriculum_state[name] = state

    def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
        """Returns the active terms as iterable sequence of tuples.

        The first element of the tuple is the name of the term and the second element is the raw value(s) of the term.

        Args:
            env_idx: The specific environment to pull the active terms from. This implementation
                does not read the argument; the stored state of each term is returned as is.

        Returns:
            The active terms. Terms whose state is None are omitted.
        """
        terms = []

        for term_name, term_state in self._curriculum_state.items():
            if term_state is not None:
                # collect the value(s) of this term
                data = []

                if isinstance(term_state, dict):
                    # each entry of the dict is a separate value of the term
                    for value in term_state.values():
                        if isinstance(value, torch.Tensor):
                            value = value.item()
                        # NOTE: values go into the per-term list; ``terms`` is a list of tuples
                        # and cannot be indexed by the term name.
                        data.append(value)
                else:
                    # use the state directly if not a dict
                    if isinstance(term_state, torch.Tensor):
                        term_state = term_state.item()
                    data.append(term_state)
                terms.append((term_name, data))

        return terms

    """
    Helper functions.
    """

    def _prepare_terms(self):
        """Parse the configuration and populate the term name/config buffers.

        Raises:
            TypeError: If a non-None configuration entry is not of type :class:`CurriculumTermCfg`.
        """
        # check if config is dict already
        if isinstance(self.cfg, dict):
            cfg_items = self.cfg.items()
        else:
            cfg_items = self.cfg.__dict__.items()
        # iterate over all the terms
        for term_name, term_cfg in cfg_items:
            # check for non config
            if term_cfg is None:
                continue
            # check if the term is a valid term config
            if not isinstance(term_cfg, CurriculumTermCfg):
                raise TypeError(
                    f"Configuration for the term '{term_name}' is not of type CurriculumTermCfg."
                    f" Received: '{type(term_cfg)}'."
                )
            # resolve common parameters
            # (two positional arguments: terms are called with the environment and the environment indices)
            self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)
            # add name and config to list
            self._term_names.append(term_name)
            self._term_cfgs.append(term_cfg)
            # check if the term is a class (these are the terms reset in :meth:`reset`)
            if isinstance(term_cfg.func, ManagerTermBase):
                self._class_term_cfgs.append(term_cfg)