# Source code for isaaclab_tasks.utils.parse_cfg

# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Sub-module with utilities for parsing and loading configurations."""

from __future__ import annotations

import collections
import importlib
import inspect
import os
import re
from typing import TYPE_CHECKING

import gymnasium as gym
import yaml

if TYPE_CHECKING:
    from isaaclab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg


def _is_preset_cfg(obj: object) -> bool:
    """Return True if *obj* is an instance of a PresetCfg subclass (new typed-field style).

    Uses MRO class-name matching so that this module has no dependency on the
    hydra-gated ``isaaclab_tasks.utils.hydra`` import path.
    """
    return any(cls.__name__ == "PresetCfg" for cls in type(obj).__mro__)


def _is_old_style_preset(obj: object) -> bool:
    """Return True if *obj* is an old-style preset wrapper (has ``presets`` dict with a ``'default'`` key)."""
    presets = getattr(obj, "presets", None)
    return hasattr(obj, "__dataclass_fields__") and isinstance(presets, dict) and "default" in presets


def _resolve_presets_to_default(cfg: object) -> object:
    """Recursively replace preset wrapper fields with their *default* preset.

    Two preset patterns used in IsaacLab task configs are handled:

    * **New style** (``PresetCfg`` subclass): typed fields where ``default`` is a class attribute.
    * **Old style** (``presets`` dict): configclass with ``presets: dict[str, Cfg]`` and a ``'default'`` key.

    Resolution happens in-place so the config is usable without a Hydra CLI override
    (e.g. in tests). The (mutated) *cfg* is also returned for convenience.
    """
    field_names = getattr(cfg, "__dataclass_fields__", None)
    if field_names is None:
        # not a dataclass/configclass -- nothing to walk
        return cfg
    for name in list(field_names):
        child = getattr(cfg, name, None)
        if child is None:
            continue
        if hasattr(child, "__dataclass_fields__"):
            if _is_preset_cfg(child):
                # new style: swap the wrapper for its `default` attribute, then recurse
                chosen = child.default
                setattr(cfg, name, chosen)
                _resolve_presets_to_default(chosen)
            elif _is_old_style_preset(child):
                # old style: swap the wrapper for presets["default"], then recurse
                chosen = child.presets["default"]
                setattr(cfg, name, chosen)
                _resolve_presets_to_default(chosen)
            else:
                # plain nested configclass: descend without replacing
                _resolve_presets_to_default(child)
        elif isinstance(child, dict):
            # descend into configclass values stored inside plain dict fields
            for entry in child.values():
                if hasattr(entry, "__dataclass_fields__"):
                    _resolve_presets_to_default(entry)
    return cfg


def apply_named_preset(env_cfg: object, raw_cfg: object, preset_name: str) -> None:
    """Apply a named preset to all preset-wrapper fields in *env_cfg*, guided by *raw_cfg*.

    *raw_cfg* is walked to locate preset wrappers (both :class:`PresetCfg` subclasses and
    old-style wrappers with a ``presets`` dict). Whenever a wrapper offers *preset_name*,
    the matching, already-resolved field on *env_cfg* is overridden with that preset.

    Used in tests to apply a non-default physics preset (e.g. ``'newton'``) after
    :func:`parse_env_cfg` has already collapsed every wrapper to ``'default'``.

    Args:
        env_cfg: Resolved env config (from :func:`parse_env_cfg`) to update in-place.
        raw_cfg: Raw env config (from :func:`load_cfg_from_registry`) with preset
            wrappers still intact.
        preset_name: Name of the preset to apply (e.g., ``'newton'``).
    """
    if not hasattr(raw_cfg, "__dataclass_fields__"):
        return
    for name in raw_cfg.__dataclass_fields__:
        raw_child = getattr(raw_cfg, name, None)
        if raw_child is None:
            continue
        is_cfg = hasattr(raw_child, "__dataclass_fields__")
        if is_cfg and _is_preset_cfg(raw_child):
            # new-style wrapper: the preset is a class attribute, if present
            if hasattr(raw_child, preset_name):
                chosen = getattr(raw_child, preset_name)
                setattr(env_cfg, name, chosen)
                apply_named_preset(chosen, chosen, preset_name)
        elif is_cfg and _is_old_style_preset(raw_child):
            # old-style wrapper: the preset lives in the `presets` dict, if present
            if preset_name in raw_child.presets:
                chosen = raw_child.presets[preset_name]
                setattr(env_cfg, name, chosen)
                apply_named_preset(chosen, chosen, preset_name)
        elif is_cfg:
            # plain nested configclass: recurse into the matching env_cfg field
            env_child = getattr(env_cfg, name, None)
            if env_child is not None and hasattr(env_child, "__dataclass_fields__"):
                apply_named_preset(env_child, raw_child, preset_name)
        elif isinstance(raw_child, dict):
            env_child = getattr(env_cfg, name, None)
            if not isinstance(env_child, dict):
                continue
            for key, raw_entry in raw_child.items():
                if key not in env_child or not hasattr(raw_entry, "__dataclass_fields__"):
                    continue
                env_entry = env_child[key]
                if hasattr(env_entry, "__dataclass_fields__"):
                    apply_named_preset(env_entry, raw_entry, preset_name)


def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | object:
    """Load default configuration given its entry point from the gym registry.

    This function loads the configuration object from the gym registry for the given task name.
    It supports both YAML and Python configuration files.

    It expects the configuration to be registered in the gym registry as:

    .. code-block:: python

        gym.register(
            id="My-Awesome-Task-v0",
            ...
            kwargs={"env_entry_point_cfg": "path.to.config:ConfigClass"},
        )

    The parsed configuration object for above example can be obtained as:

    .. code-block:: python

        from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry

        cfg = load_cfg_from_registry("My-Awesome-Task-v0", "env_cfg_entry_point")

    Args:
        task_name: The name of the environment.
        entry_point_key: The entry point key to resolve the configuration file.

    Returns:
        The parsed configuration object. If the entry point is a YAML file, it is parsed into a dictionary.
        If the entry point is a Python class, it is instantiated and returned.

    Raises:
        ValueError: If the entry point key is not available in the gym registry for the task.
    """
    # obtain the configuration entry point
    cfg_entry_point = gym.spec(task_name.split(":")[-1]).kwargs.get(entry_point_key)
    # check if entry point exists
    if cfg_entry_point is None:
        # get existing agents and algorithms to build a helpful error message
        agents = collections.defaultdict(list)
        for k in gym.spec(task_name.split(":")[-1]).kwargs:
            if k.endswith("_cfg_entry_point") and k != "env_cfg_entry_point":
                # entry-point keys look like "<agent>_<algorithm>_cfg_entry_point";
                # library names containing underscores are temporarily hyphenated so
                # the split on "_" separates agent from algorithm correctly
                spec = (
                    k.replace("_cfg_entry_point", "")
                    .replace("rl_games", "rl-games")
                    .replace("rsl_rl", "rsl-rl")
                    .split("_")
                )
                agent = spec[0].replace("-", "_")
                # default to PPO when no algorithm suffix is present
                algorithms = [item.upper() for item in (spec[1:] if len(spec) > 1 else ["PPO"])]
                agents[agent].extend(algorithms)
        msg = "\nExisting RL library (and algorithms) config entry points: "
        for agent, algorithms in agents.items():
            msg += f"\n |-- {agent}: {', '.join(algorithms)}"
        # raise error
        raise ValueError(
            f"Could not find configuration for the environment: '{task_name}'."
            f"\nPlease check that the gym registry has the entry point: '{entry_point_key}'."
            f"{msg if agents else ''}"
        )
    # parse the default config file
    if isinstance(cfg_entry_point, str) and cfg_entry_point.endswith(".yaml"):
        if os.path.exists(cfg_entry_point):
            # absolute path for the config file
            config_file = cfg_entry_point
        else:
            # resolve path to the module location
            mod_name, file_name = cfg_entry_point.split(":")
            mod_path = os.path.dirname(importlib.import_module(mod_name).__file__)
            # obtain the configuration file path
            config_file = os.path.join(mod_path, file_name)
        # load the configuration
        print(f"[INFO]: Parsing configuration from: {config_file}")
        with open(config_file, encoding="utf-8") as f:
            cfg = yaml.full_load(f)
    else:
        if callable(cfg_entry_point):
            # resolve path to the module location
            mod_path = inspect.getfile(cfg_entry_point)
            # load the configuration
            cfg_cls = cfg_entry_point()
        elif isinstance(cfg_entry_point, str):
            # resolve path to the module location
            mod_name, attr_name = cfg_entry_point.split(":")
            mod = importlib.import_module(mod_name)
            cfg_cls = getattr(mod, attr_name)
        else:
            cfg_cls = cfg_entry_point
        # load the configuration
        print(f"[INFO]: Parsing configuration from: {cfg_entry_point}")
        # instantiate only if what we resolved is still callable (a class/factory)
        if callable(cfg_cls):
            cfg = cfg_cls()
        else:
            cfg = cfg_cls
    return cfg
def parse_env_cfg(
    task_name: str, device: str = "cuda:0", num_envs: int | None = None, use_fabric: bool | None = None
) -> ManagerBasedRLEnvCfg | DirectRLEnvCfg:
    """Parse configuration for an environment and override based on inputs.

    Args:
        task_name: The name of the environment.
        device: The device to run the simulation on. Defaults to "cuda:0".
        num_envs: Number of environments to create. Defaults to None, in which case it is left unchanged.
        use_fabric: Whether to enable/disable fabric interface. If false, all read/write operations go through USD.
            This slows down the simulation but allows seeing the changes in the USD through the USD stage.
            Defaults to None, in which case it is left unchanged.

    Returns:
        The parsed configuration object.

    Raises:
        RuntimeError: If the configuration for the task is not a class. We assume users always use a class for the
            environment configuration.
    """
    # load the default configuration
    cfg = load_cfg_from_registry(task_name.split(":")[-1], "env_cfg_entry_point")
    # check that it is not a dict
    # we assume users always use a class for the configuration
    if isinstance(cfg, dict):
        raise RuntimeError(f"Configuration for the task: '{task_name}' is not a class. Please provide a class.")
    # If the top-level cfg is itself a PresetCfg wrapper, resolve to the default preset before
    # attempting any attribute access (e.g. cfg.sim, cfg.scene).
    if _is_preset_cfg(cfg):
        cfg = cfg.default
    # Resolve any PresetCfg wrappers to their default preset so the config
    # is usable without a Hydra CLI override (e.g. in tests).
    # Must happen BEFORE attribute overrides, otherwise overrides on PresetCfg wrapper
    # fields (e.g. cfg.scene when scene is a PresetCfg) get discarded when the wrapper
    # is replaced by its .default.
    _resolve_presets_to_default(cfg)
    # simulation device
    cfg.sim.device = device
    # disable fabric to read/write through USD
    if use_fabric is not None:
        cfg.sim.use_fabric = use_fabric
    # number of environments
    if num_envs is not None:
        cfg.scene.num_envs = num_envs
    return cfg
[docs] def get_checkpoint_path( log_path: str, run_dir: str = ".*", checkpoint: str = ".*", other_dirs: list[str] = None, sort_alpha: bool = True ) -> str: """Get path to the model checkpoint in input directory. The checkpoint file is resolved as: ``<log_path>/<run_dir>/<*other_dirs>/<checkpoint>``, where the :attr:`other_dirs` are intermediate folder names to concatenate. These cannot be regex expressions. If :attr:`run_dir` and :attr:`checkpoint` are regex expressions then the most recent (highest alphabetical order) run and checkpoint are selected. To disable this behavior, set the flag :attr:`sort_alpha` to False. Args: log_path: The log directory path to find models in. run_dir: The regex expression for the name of the directory containing the run. Defaults to the most recent directory created inside :attr:`log_path`. other_dirs: The intermediate directories between the run directory and the checkpoint file. Defaults to None, which implies that checkpoint file is directly under the run directory. checkpoint: The regex expression for the model checkpoint file. Defaults to the most recent torch-model saved in the :attr:`run_dir` directory. sort_alpha: Whether to sort the runs by alphabetical order. Defaults to True. If False, the folders in :attr:`run_dir` are sorted by the last modified time. Returns: The path to the model checkpoint. Raises: ValueError: When no runs are found in the input directory. ValueError: When no checkpoints are found in the input directory. 
""" # check if runs present in directory try: # find all runs in the directory that math the regex expression runs = [ os.path.join(log_path, run) for run in os.scandir(log_path) if run.is_dir() and re.match(run_dir, run.name) ] # sort matched runs by alphabetical order (latest run should be last) if sort_alpha: runs.sort() else: runs = sorted(runs, key=os.path.getmtime) # create last run file path if other_dirs is not None: run_path = os.path.join(runs[-1], *other_dirs) else: run_path = runs[-1] except IndexError: raise ValueError(f"No runs present in the directory: '{log_path}' match: '{run_dir}'.") # list all model checkpoints in the directory model_checkpoints = [f for f in os.listdir(run_path) if re.match(checkpoint, f)] # check if any checkpoints are present if len(model_checkpoints) == 0: raise ValueError(f"No checkpoints in the directory: '{run_path}' match '{checkpoint}'.") # sort alphabetically while ensuring that *_10 comes after *_9 model_checkpoints.sort(key=lambda m: f"{m:0>15}") # get latest matched checkpoint file checkpoint_file = model_checkpoints[-1] return os.path.join(run_path, checkpoint_file)