# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Sub-module that provides a wrapper around the Python 3.7 onwards ``dataclasses`` module."""

import inspect
import types
from collections.abc import Callable
from copy import deepcopy
from dataclasses import MISSING, Field, dataclass, field, replace
from typing import Any, ClassVar

from .dict import class_to_dict, update_class_from_dict

_CONFIGCLASS_METHODS = ["to_dict", "from_dict", "replace", "copy", "validate"]
"""List of class methods added at runtime to dataclass."""

"""
Wrapper around dataclass.
"""


def __dataclass_transform__():
    """Return an identity decorator.

    The decorator hands back whatever it is given unchanged; per the original
    note, it exists so that PyLance picks up dataclass-style annotations on the
    decorated function.

    Returns:
        A one-argument callable that returns its argument as-is.
    """

    def _identity(decorated):
        # nothing is wrapped or modified: the argument passes straight through
        return decorated

    return _identity
@__dataclass_transform__()
def configclass(cls=None, **kwargs):
    """Wrapper around `dataclass` functionality to add extra checks and utilities.

    As of Python 3.7, the standard dataclasses have two main issues which makes them non-generic for
    configuration use-cases. These include:

    1. Requiring a type annotation for all its members.
    2. Requiring explicit usage of :meth:`field(default_factory=...)` to reinitialize mutable variables.

    This function provides a decorator that wraps around Python's `dataclass`_ utility to deal with
    the above two issues. It also provides additional helper functions for dictionary <-> class
    conversion and easily copying class instances.

    The decorator can be applied bare (``@configclass``), called directly on a class
    (``configclass(MyCfg, eq=False)``), or parametrized (``@configclass(eq=False)``). In all cases the
    keyword arguments are forwarded to :func:`dataclass`.

    Usage:

    .. code-block:: python

        from dataclasses import MISSING, field

        from isaaclab.utils.configclass import configclass


        @configclass
        class ViewerCfg:
            eye: list = [7.5, 7.5, 7.5]  # field missing on purpose
            lookat: list = field(default_factory=lambda: [0.0, 0.0, 0.0])


        @configclass
        class EnvCfg:
            num_envs: int = MISSING
            episode_length: int = 2000
            viewer: ViewerCfg = ViewerCfg()


        # create configuration instance
        env_cfg = EnvCfg(num_envs=24)

        # print information as a dictionary
        print(env_cfg.to_dict())

        # create a copy of the configuration
        env_cfg_copy = env_cfg.copy()

        # replace arbitrary fields using keyword arguments
        env_cfg_copy = env_cfg_copy.replace(num_envs=32)

    Args:
        cls: The class to wrap around. Defaults to None, in which case a decorator that applies
            the given keyword arguments is returned instead of a wrapped class.
        **kwargs: Additional arguments to pass to :func:`dataclass`.

    Returns:
        The wrapped class, or a decorator producing the wrapped class when ``cls`` is None.

    .. _dataclass: https://docs.python.org/3/library/dataclasses.html
    """
    # parametrized usage (``@configclass(...)``): defer the wrapping until the class is supplied
    if cls is None:
        return lambda wrapped_cls: configclass(wrapped_cls, **kwargs)

    # add type annotations
    _add_annotation_types(cls)
    # add field factory
    _process_mutable_types(cls)
    # copy mutable members
    # note: we check if user defined __post_init__ function exists and augment it with our own.
    #   A hook inherited from an already-wrapped parent is reused as-is: wrapping it again would
    #   only repeat the deep-copy pass once more per inheritance level.
    post_init = getattr(cls, "__post_init__", None)
    already_instrumented = post_init is _custom_post_init or getattr(post_init, "_configclass_post_init", False)
    if post_init is None:
        cls.__post_init__ = _custom_post_init
    elif not already_instrumented:
        combined_post_init = _combined_function(post_init, _custom_post_init)
        # tag the combined hook so that subclasses can recognize it as already instrumented
        combined_post_init._configclass_post_init = True
        cls.__post_init__ = combined_post_init
    # add helper functions for dictionary conversion
    cls.to_dict = _class_to_dict
    cls.from_dict = _update_class_from_dict
    cls.replace = _replace_class_with_kwargs
    cls.copy = _copy_class
    cls.validate = _validate
    # wrap around dataclass and return the wrapped class
    return dataclass(cls, **kwargs)
"""Dictionary <-> Class operations.These are redefined here to add new docstrings."""def_class_to_dict(obj:object)->dict[str,Any]:"""Convert an object into dictionary recursively. Args: obj: The object to convert. Returns: Converted dictionary mapping. """returnclass_to_dict(obj)def_update_class_from_dict(obj,data:dict[str,Any])->None:"""Reads a dictionary and sets object variables recursively. This function performs in-place update of the class member attributes. Args: obj: The object to update. data: Input (nested) dictionary to update from. Raises: TypeError: When input is not a dictionary. ValueError: When dictionary has a value that does not match default config type. KeyError: When dictionary has a key that does not exist in the default config type. """update_class_from_dict(obj,data,_ns="")def_replace_class_with_kwargs(obj:object,**kwargs)->object:"""Return a new object replacing specified fields with new values. This is especially useful for frozen classes. Example usage: .. code-block:: python @configclass(frozen=True) class C: x: int y: int c = C(1, 2) c1 = c.replace(x=3) assert c1.x == 3 and c1.y == 2 Args: obj: The object to replace. **kwargs: The fields to replace and their new values. Returns: The new object. """returnreplace(obj,**kwargs)def_copy_class(obj:object)->object:"""Return a new object with the same fields as the original."""returnreplace(obj)"""Private helper functions."""def_add_annotation_types(cls):"""Add annotations to all elements in the dataclass. By definition in Python, a field is defined as a class variable that has a type annotation. In case type annotations are not provided, dataclass ignores those members when :func:`__dict__()` is called. This function adds these annotations to the class variable to prevent any issues in case the user forgets to specify the type annotation. 
This makes the following a feasible operation: @dataclass class State: pos = (0.0, 0.0, 0.0) ^^ If the function is NOT used, the following type-error is returned: TypeError: 'pos' is a field but has no type annotation """# get type hintshints={}# iterate over class inheritance# we add annotations from base classes firstforbaseinreversed(cls.__mro__):# check if base is objectifbaseisobject:continue# get base class annotationsann=base.__dict__.get("__annotations__",{})# directly add all annotations from base classhints.update(ann)# iterate over base class members# Note: Do not change this to dir(base) since it orders the members alphabetically.# This is not desirable since the order of the members is important in some cases.forkeyinbase.__dict__:# get class membervalue=getattr(base,key)# skip membersif_skippable_class_member(key,value,hints):continue# add type annotations for members that don't have explicit type annotations# for these, we deduce the type from the default valueifnotisinstance(value,type):ifkeynotinhints:# check if var type is not MISSING# we cannot deduce type from MISSING!ifvalueisMISSING:raiseTypeError(f"Missing type annotation for '{key}' in class '{cls.__name__}'."" Please add a type annotation or set a default value.")# add type annotationhints[key]=type(value)elifkey!=value.__name__:# note: we don't want to add type annotations for nested configclass. Thus, we check if# the name of the type matches the name of the variable.# since Python 3.10, type hints are stored as stringshints[key]=f"type[{value.__name__}]"# Note: Do not change this line. `cls.__dict__.get("__annotations__", {})` is different from# `cls.__annotations__` because of inheritance.cls.__annotations__=cls.__dict__.get("__annotations__",{})cls.__annotations__=hintsdef_validate(obj:object,prefix:str="")->list[str]:"""Check the validity of configclass object. This function checks if the object is a valid configclass object. A valid configclass object contains no MISSING entries. 
Args: obj: The object to check. prefix: The prefix to add to the missing fields. Defaults to ''. Returns: A list of missing fields. Raises: TypeError: When the object is not a valid configuration object. """missing_fields=[]iftype(obj)istype(MISSING):missing_fields.append(prefix)returnmissing_fieldselifisinstance(obj,(list,tuple)):forindex,iteminenumerate(obj):current_path=f"{prefix}[{index}]"missing_fields.extend(_validate(item,prefix=current_path))returnmissing_fieldselifisinstance(obj,dict):obj_dict=objelifhasattr(obj,"__dict__"):obj_dict=obj.__dict__else:returnmissing_fieldsforkey,valueinobj_dict.items():# disregard builtin attributesifkey.startswith("__"):continuecurrent_path=f"{prefix}.{key}"ifprefixelsekeymissing_fields.extend(_validate(value,prefix=current_path))# raise an error only once at the top-level callifprefix==""andmissing_fields:formatted_message="\n".join(f" - {field}"forfieldinmissing_fields)raiseTypeError(f"Missing values detected in object {obj.__class__.__name__} for the following"f" fields:\n{formatted_message}\n")returnmissing_fieldsdef_process_mutable_types(cls):"""Initialize all mutable elements through :obj:`dataclasses.Field` to avoid unnecessary complaints. By default, dataclass requires usage of :obj:`field(default_factory=...)` to reinitialize mutable objects every time a new class instance is created. If a member has a mutable type and it is created without specifying the `field(default_factory=...)`, then Python throws an error requiring the usage of `default_factory`. Additionally, Python only explicitly checks for field specification when the type is a list, set or dict. This misses the use-case where the type is class itself. Thus, the code silently carries a bug with it which can lead to undesirable effects. 
This function deals with this issue This makes the following a feasible operation: @dataclass class State: pos: list = [0.0, 0.0, 0.0] ^^ If the function is NOT used, the following value-error is returned: ValueError: mutable default <class 'list'> for field pos is not allowed: use default_factory """# note: Need to set this up in the same order as annotations. Otherwise, it# complains about missing positional arguments.ann=cls.__dict__.get("__annotations__",{})# iterate over all class members and store them in a dictionaryclass_members={}forbaseinreversed(cls.__mro__):# check if base is objectifbaseisobject:continue# iterate over base class membersforkeyinbase.__dict__:# get class memberf=getattr(base,key)# skip membersif_skippable_class_member(key,f):continue# store class member if it is not a type or if it is already present in annotationsifnotisinstance(f,type)orkeyinann:class_members[key]=f# iterate over base class data fields# in previous call, things that became a dataclass field were removed from class members# so we need to add them back here as a dataclass field directlyforkey,finbase.__dict__.get("__dataclass_fields__",{}).items():# store class memberifnotisinstance(f,type):class_members[key]=f# check that all annotations are present in class members# note: mainly for debugging purposesiflen(class_members)!=len(ann):raiseValueError(f"In class '{cls.__name__}', number of annotations ({len(ann)}) does not match number of class members"f" ({len(class_members)}). Please check that all class members have type annotations and/or a default"" value. 
If you don't want to specify a default value, please use the literal `dataclasses.MISSING`.")# iterate over annotations and add field factory for mutable typesforkeyinann:# find matching field in classvalue=class_members.get(key,MISSING)# check if key belongs to ClassVar# in that case, we cannot use default_factory!origin=getattr(ann[key],"__origin__",None)iforiginisClassVar:continue# check if f is MISSING# note: commented out for now since it causes issue with inheritance# of dataclasses when parent have some positional and some keyword arguments.# Ref: https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses# TODO: check if this is fixed in Python 3.10# if f is MISSING:# continueifisinstance(value,Field):setattr(cls,key,value)elifnotisinstance(value,type):# create field factory for mutable typesvalue=field(default_factory=_return_f(value))setattr(cls,key,value)def_custom_post_init(obj):"""Deepcopy all elements to avoid shared memory issues for mutable objects in dataclasses initialization. This function is called explicitly instead of as a part of :func:`_process_mutable_types()` to prevent mapping proxy type i.e. a read only proxy for mapping objects. The error is thrown when using hierarchical data-classes for configuration. """forkeyindir(obj):# skip dunder membersifkey.startswith("__"):continue# get data membervalue=getattr(obj,key)# check annotationann=obj.__class__.__dict__.get(key)# duplicate data members that are mutableifnotcallable(value)andnotisinstance(ann,property):setattr(obj,key,deepcopy(value))def_combined_function(f1:Callable,f2:Callable)->Callable:"""Combine two functions into one. Args: f1: The first function. f2: The second function. Returns: The combined function. 
"""def_combined(*args,**kwargs):# call both functionsf1(*args,**kwargs)f2(*args,**kwargs)return_combined"""Helper functions"""def_skippable_class_member(key:str,value:Any,hints:dict|None=None)->bool:"""Check if the class member should be skipped in configclass processing. The following members are skipped: * Dunder members: ``__name__``, ``__module__``, ``__qualname__``, ``__annotations__``, ``__dict__``. * Manually-added special class functions: From :obj:`_CONFIGCLASS_METHODS`. * Members that are already present in the type annotations. * Functions bounded to class object or class. * Properties bounded to class object. Args: key: The class member name. value: The class member value. hints: The type hints for the class. Defaults to None, in which case, the members existence in type hints are not checked. Returns: True if the class member should be skipped, False otherwise. """# skip dunder membersifkey.startswith("__"):returnTrue# skip manually-added special class functionsifkeyin_CONFIGCLASS_METHODS:returnTrue# check if key is already presentifhintsisnotNoneandkeyinhints:returnTrue# skip functions bounded to classifcallable(value):# FIXME: This doesn't yet work for static methods because they are essentially seen as function types.# check for class methodsifisinstance(value,types.MethodType):returnTrue# check for instance methodssignature=inspect.signature(value)if"self"insignature.parametersor"cls"insignature.parameters:returnTrue# skip property methodsifisinstance(value,property):returnTrue# Otherwise, don't skipreturnFalsedef_return_f(f:Any)->Callable[[],Any]:"""Returns default factory function for creating mutable/immutable variables. This function should be used to create default factory functions for variables. Example: .. 
code-block:: python value = field(default_factory=_return_f(value)) setattr(cls, key, value) """def_wrap():ifisinstance(f,Field):iff.default_factoryisMISSING:returndeepcopy(f.default)else:returnf.default_factoryelse:returndeepcopy(f)return_wrap