# Source code for isaaclab.utils.datasets.episode_data

# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

# Copyright (c) 2024-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch


[docs] class EpisodeData: """Class to store episode data.""" def __init__(self) -> None: """Initializes episode data class.""" self._data = dict() self._next_action_index = 0 self._next_state_index = 0 self._next_joint_target_index = 0 self._seed = None self._env_id = None self._success = None @property def data(self): """Returns the episode data.""" return self._data @data.setter def data(self, data: dict): """Set the episode data.""" self._data = data @property def seed(self): """Returns the random number generator seed.""" return self._seed @seed.setter def seed(self, seed: int): """Set the random number generator seed.""" self._seed = seed @property def env_id(self): """Returns the environment ID.""" return self._env_id @env_id.setter def env_id(self, env_id: int): """Set the environment ID.""" self._env_id = env_id @property def next_action_index(self): """Returns the next action index.""" return self._next_action_index @next_action_index.setter def next_action_index(self, index: int): """Set the next action index.""" self._next_action_index = index @property def next_state_index(self): """Returns the next state index.""" return self._next_state_index @next_state_index.setter def next_state_index(self, index: int): """Set the next state index.""" self._next_state_index = index @property def success(self): """Returns the success value.""" return self._success @success.setter def success(self, success: bool): """Set the success value.""" self._success = success
[docs] def is_empty(self): """Check if the episode data is empty.""" return not bool(self._data)
[docs] def add(self, key: str, value: torch.Tensor | dict): """Add a key-value pair to the dataset. The key can be nested by using the "/" character. For example: "obs/joint_pos". Args: key: The key name. value: The corresponding value of tensor type or of dict type. """ # check datatype if isinstance(value, dict): for sub_key, sub_value in value.items(): self.add(f"{key}/{sub_key}", sub_value) return sub_keys = key.split("/") current_dataset_pointer = self._data for sub_key_index in range(len(sub_keys)): if sub_key_index == len(sub_keys) - 1: # Add value to the final dict layer # Use lists to prevent slow tensor copy during concatenation if sub_keys[sub_key_index] not in current_dataset_pointer: current_dataset_pointer[sub_keys[sub_key_index]] = [value.clone()] else: current_dataset_pointer[sub_keys[sub_key_index]].append(value.clone()) break # key index if sub_keys[sub_key_index] not in current_dataset_pointer: current_dataset_pointer[sub_keys[sub_key_index]] = dict() current_dataset_pointer = current_dataset_pointer[sub_keys[sub_key_index]]
[docs] def get_initial_state(self) -> torch.Tensor | None: """Get the initial state from the dataset.""" if "initial_state" not in self._data: return None return self._data["initial_state"]
[docs] def get_action(self, action_index) -> torch.Tensor | None: """Get the action of the specified index from the dataset.""" if "actions" not in self._data: return None if action_index >= len(self._data["actions"]): return None return self._data["actions"][action_index]
[docs] def get_next_action(self) -> torch.Tensor | None: """Get the next action from the dataset.""" action = self.get_action(self._next_action_index) if action is not None: self._next_action_index += 1 return action
[docs] def get_state(self, state_index) -> dict | None: """Get the state of the specified index from the dataset.""" if "states" not in self._data: return None states = self._data["states"] def get_state_helper(states, state_index) -> dict | torch.Tensor | None: if isinstance(states, dict): output_state = dict() for key, value in states.items(): output_state[key] = get_state_helper(value, state_index) if output_state[key] is None: return None elif isinstance(states, torch.Tensor): if state_index >= len(states): return None output_state = states[state_index, None] else: raise ValueError(f"Invalid state type: {type(states)}") return output_state output_state = get_state_helper(states, state_index) return output_state
[docs] def get_next_state(self) -> dict | None: """Get the next state from the dataset.""" state = self.get_state(self._next_state_index) if state is not None: self._next_state_index += 1 return state
[docs] def get_joint_target(self, joint_target_index) -> dict | torch.Tensor | None: """Get the joint target of the specified index from the dataset.""" if "joint_targets" not in self._data: return None joint_targets = self._data["joint_targets"] def get_joint_target_helper(joint_targets, joint_target_index) -> dict | torch.Tensor | None: if isinstance(joint_targets, dict): output_joint_targets = dict() for key, value in joint_targets.items(): output_joint_targets[key] = get_joint_target_helper(value, joint_target_index) if output_joint_targets[key] is None: return None elif isinstance(joint_targets, torch.Tensor): if joint_target_index >= len(joint_targets): return None output_joint_targets = joint_targets[joint_target_index] else: raise ValueError(f"Invalid joint target type: {type(joint_targets)}") return output_joint_targets output_joint_targets = get_joint_target_helper(joint_targets, joint_target_index) return output_joint_targets
[docs] def get_next_joint_target(self) -> dict | torch.Tensor | None: """Get the next joint target from the dataset.""" joint_target = self.get_joint_target(self._next_joint_target_index) if joint_target is not None: self._next_joint_target_index += 1 return joint_target
[docs] def pre_export(self): """Prepare data for export by converting lists to tensors.""" def pre_export_helper(data): for key, value in data.items(): if isinstance(value, list): data[key] = torch.stack(value) elif isinstance(value, dict): pre_export_helper(value) pre_export_helper(self._data)