# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import json
import os
from collections.abc import Iterable
import h5py
import numpy as np
import torch
from isaaclab.utils.math import convert_quat
from .dataset_file_handler_base import DatasetFileHandlerBase
from .episode_data import EpisodeData
# Dataset format version written to the file-level "format_version" attribute.
#   Version 1: quaternions are stored in XYZW order (current).
#   Version 0 (or attribute missing): legacy files with quaternions in WXYZ order.
DATASET_FORMAT_VERSION = 1
def convert_pose_quat_wxyz_to_xyzw(pose: np.ndarray) -> np.ndarray:
    """Reorder the quaternion part of a pose from WXYZ to XYZW.

    The last axis of ``pose`` is read as three position components followed by four
    quaternion components. The position is passed through unchanged.

    Args:
        pose: Pose array with shape (..., 7) whose quaternion is in WXYZ format.

    Returns:
        Pose array with shape (..., 7) whose quaternion is in XYZW format.
    """
    xyz, wxyz = pose[..., :3], pose[..., 3:7]
    return np.concatenate((xyz, convert_quat(wxyz, to="xyzw")), axis=-1)
[docs]
class HDF5DatasetFileHandler(DatasetFileHandlerBase):
    """HDF5 dataset file handler for storing and loading episode data.

    Episodes are kept as sub-groups of a top-level "data" group. The environment
    arguments are stored as a JSON string in the "env_args" attribute of that group,
    and the dataset format version in the file-level "format_version" attribute.
    """
def __init__(self):
    """Set up a handler that is not yet attached to any file."""
    # nothing is open until open() or create() is called
    self._hdf5_file_stream, self._hdf5_data_group = None, None
    # bookkeeping for the episodes and the environment arguments
    self._demo_count = 0
    self._env_args = {}
[docs]
def open(self, file_path: str, mode: str = "r"):
    """Open an existing dataset file.

    The handler state is only updated once the file has been opened and its "data"
    group has been read, so a failed call leaves the handler unopened and reusable.

    Args:
        file_path: Path of the HDF5 file to open.
        mode: File mode forwarded to ``h5py.File``. Defaults to "r".

    Raises:
        RuntimeError: If this handler already holds a file stream.
    """
    if self._hdf5_file_stream is not None:
        raise RuntimeError("HDF5 dataset file stream is already in use")
    file_stream = h5py.File(file_path, mode)
    try:
        data_group = file_stream["data"]
        demo_count = len(data_group)
    except Exception:
        # do not keep a half-initialized stream around when the "data" group is unusable
        file_stream.close()
        raise
    self._hdf5_file_stream = file_stream
    self._hdf5_data_group = data_group
    self._demo_count = demo_count
[docs]
def create(self, file_path: str, env_name: str | None = None):
    """Create a new dataset file.

    The ".hdf5" extension is appended when missing and the parent directory is created
    when needed. The new file gets the current format version, an empty "data" group
    and the initial environment arguments.

    Args:
        file_path: Path of the file to create. An existing file is overwritten.
        env_name: Name of the environment. Defaults to None, which is stored as "".

    Raises:
        RuntimeError: If this handler already holds a file stream.
    """
    if self._hdf5_file_stream is not None:
        raise RuntimeError("HDF5 dataset file stream is already in use")
    if not file_path.endswith(".hdf5"):
        file_path += ".hdf5"
    dir_path = os.path.dirname(file_path)
    # a bare file name has an empty dirname, which os.makedirs rejects
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    self._hdf5_file_stream = h5py.File(file_path, "w")
    # Set the dataset format version
    self._hdf5_file_stream.attrs["format_version"] = DATASET_FORMAT_VERSION
    # set up a data group in the file
    self._hdf5_data_group = self._hdf5_file_stream.create_group("data")
    self._hdf5_data_group.attrs["total"] = 0
    self._demo_count = 0
    # start from a clean slate so arguments of a previously handled file do not leak in
    self._env_args = {}
    # set environment arguments
    # the environment type (we use gym environment type) is set to be compatible with robomimic
    # Ref: https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/envs/env_base.py#L15
    env_name = env_name if env_name is not None else ""
    self.add_env_args({"env_name": env_name, "type": 2})
def __del__(self):
    """Destructor for the file handler: closes the file stream if one is still open."""
    self.close()
"""
Properties
"""
[docs]
def add_env_args(self, env_args: dict):
    """Add environment arguments to the dataset.

    The given entries are merged into the arguments already stored in the file and the
    result is written back as a JSON string to the "env_args" attribute of the data group.

    Args:
        env_args: Environment arguments to add. Existing keys are overwritten.

    Raises:
        RuntimeError: If no dataset file is open.
    """
    self._raise_if_not_initialized()
    attrs = self._hdf5_data_group.attrs
    stored = attrs.get("env_args")
    if stored is not None:
        # After open() the in-memory cache is empty; seed it from the file so that
        # entries written earlier (e.g. "type") are not dropped by this update.
        try:
            persisted = json.loads(stored)
        except (TypeError, ValueError):
            # unreadable attribute: keep the previous behavior of overwriting it
            persisted = None
        if isinstance(persisted, dict):
            self._env_args.update(persisted)
    self._env_args.update(env_args)
    attrs["env_args"] = json.dumps(self._env_args)
[docs]
def set_env_name(self, env_name: str):
    """Store ``env_name`` under the "env_name" key of the environment arguments."""
    self._raise_if_not_initialized()
    name_entry = {"env_name": env_name}
    self.add_env_args(name_entry)
[docs]
def get_env_name(self) -> str | None:
    """Get the environment name.

    Returns:
        The value stored under "env_name" in the environment arguments, or None when
        the dataset has no environment arguments or no such entry.

    Raises:
        RuntimeError: If no dataset file is open.
    """
    self._raise_if_not_initialized()
    raw_env_args = self._hdf5_data_group.attrs.get("env_args")
    # datasets without an "env_args" attribute simply have no environment name
    if raw_env_args is None:
        return None
    return json.loads(raw_env_args).get("env_name")
[docs]
def get_episode_names(self) -> Iterable[str]:
    """Return the keys of the data group, i.e. the names of the stored episodes."""
    self._raise_if_not_initialized()
    data_group = self._hdf5_data_group
    return data_group.keys()
[docs]
def get_num_episodes(self) -> int:
    """Return the episode count tracked by this handler."""
    return self._demo_count
@property
def demo_count(self) -> int:
    """Demo counter of this handler (read-only view of ``_demo_count``)."""
    return self._demo_count
"""
Operations.
"""
[docs]
def load_episode(
    self, episode_name: str, device: str, convert_legacy_quat: bool | None = None
) -> EpisodeData | None:
    """Load episode data from the file.

    Args:
        episode_name: Name of the episode to load.
        device: Device to load tensors to.
        convert_legacy_quat: If True, convert the quaternions of "root_pose" datasets
            from legacy WXYZ to XYZW format. If None (default), auto-detect based on
            dataset version.

    Returns:
        The loaded episode data, or None if the episode doesn't exist.

    Raises:
        RuntimeError: If no dataset file is open.
    """
    self._raise_if_not_initialized()
    if episode_name not in self._hdf5_data_group:
        return None
    # Auto-detect if conversion is needed
    if convert_legacy_quat is None:
        # NOTE(review): `is_legacy_quaternion_format` is not defined in the visible part of
        # this class -- presumably provided elsewhere; confirm.
        convert_legacy_quat = self.is_legacy_quaternion_format()

    def load_dataset_helper(group):
        """Helper method to load dataset that contains recursive dict objects."""
        data = {}
        # items() resolves every child once instead of repeating the group[key] lookup
        for key, item in group.items():
            if isinstance(item, h5py.Group):
                data[key] = load_dataset_helper(item)
                continue
            # Converting the dataset to numpy array greatly improves the performance
            # when converting to torch tensor
            np_data = np.array(item)
            # Convert legacy quaternions if needed (0-d datasets have no last axis to inspect)
            if convert_legacy_quat and key == "root_pose" and np_data.ndim > 0 and np_data.shape[-1] == 7:
                np_data = convert_pose_quat_wxyz_to_xyzw(np_data)
            data[key] = torch.tensor(np_data, device=device)
        return data

    h5_episode_group = self._hdf5_data_group[episode_name]
    episode = EpisodeData()
    episode.data = load_dataset_helper(h5_episode_group)
    # seed and success are optional attributes of the episode group
    if "seed" in h5_episode_group.attrs:
        episode.seed = h5_episode_group.attrs["seed"]
    if "success" in h5_episode_group.attrs:
        episode.success = h5_episode_group.attrs["success"]
    episode.env_id = self.get_env_name()
    return episode
[docs]
def write_episode(self, episode: EpisodeData, demo_id: int | None = None):
    """Add an episode to the dataset.

    Empty episodes are skipped. The episode is stored in a group named "demo_<index>";
    nested dicts in the episode data become sub-groups and every other value is written
    as a gzip-compressed dataset.

    Args:
        episode: The episode data to add.
        demo_id: Custom index for the episode. If None, uses default index.

    Raises:
        RuntimeError: If no dataset file is open.
        ValueError: If an episode group with the resulting name already exists.
    """
    self._raise_if_not_initialized()
    if episode.is_empty():
        return
    # Use custom demo id if provided, otherwise use default naming
    episode_index = self._demo_count if demo_id is None else demo_id
    episode_group_name = f"demo_{episode_index}"
    # create episode group with the specified name
    if episode_group_name in self._hdf5_data_group:
        raise ValueError(f"Episode group '{episode_group_name}' already exists in the dataset")
    h5_episode_group = self._hdf5_data_group.create_group(episode_group_name)
    # store number of steps taken
    num_samples = len(episode.data["actions"]) if "actions" in episode.data else 0
    h5_episode_group.attrs["num_samples"] = num_samples
    if episode.seed is not None:
        h5_episode_group.attrs["seed"] = episode.seed
    if episode.success is not None:
        h5_episode_group.attrs["success"] = episode.success

    def create_dataset_helper(group, key, value):
        """Helper method to create dataset that contains recursive dict objects."""
        if isinstance(value, dict):
            key_group = group.create_group(key)
            for sub_key, sub_value in value.items():
                create_dataset_helper(key_group, sub_key, sub_value)
        else:
            # detach first: numpy() refuses tensors that still require grad
            group.create_dataset(key, data=value.detach().cpu().numpy(), compression="gzip")

    try:
        for key, value in episode.data.items():
            create_dataset_helper(h5_episode_group, key, value)
    except Exception:
        # do not leave a partially written episode behind in the dataset
        del self._hdf5_data_group[episode_group_name]
        raise
    # increment total step counts (files without a "total" attribute start from zero)
    data_attrs = self._hdf5_data_group.attrs
    data_attrs["total"] = data_attrs.get("total", 0) + num_samples
    # Only increment demo count if using default indexing
    if demo_id is None:
        # increment total demo counts
        self._demo_count += 1
[docs]
def flush(self):
    """Ask the open file stream to flush its buffered data."""
    self._raise_if_not_initialized()
    stream = self._hdf5_file_stream
    stream.flush()
[docs]
def close(self):
    """Close the dataset file handler.

    Does nothing when no file is open, so it is safe to call repeatedly. The handler
    drops its references to the file and the data group even if closing the file raises.
    """
    # getattr: the destructor may run on an object whose __init__ never completed
    stream = getattr(self, "_hdf5_file_stream", None)
    if stream is None:
        return
    try:
        stream.close()
    finally:
        # the data group belongs to the closed file and must not be used afterwards
        self._hdf5_file_stream = None
        self._hdf5_data_group = None
def _raise_if_not_initialized(self):
    """Guard for methods that need an open file: raise when no file stream is set."""
    if self._hdf5_file_stream is not None:
        return
    raise RuntimeError("HDF5 dataset file stream is not initialized")
[docs]
@staticmethod
def convert_dataset_to_xyzw(input_path: str, output_path: str | None = None) -> str:
    """Convert a legacy dataset from WXYZ to XYZW quaternion format.

    This method reads a dataset file, converts the quaternions of all "root_pose"
    datasets from the legacy WXYZ format to the current XYZW format, and writes the
    result to a new file marked with the current format version.

    Args:
        input_path: Path to the input dataset file (legacy WXYZ format).
        output_path: Path for the output dataset file. If None, appends '_xyzw' to input filename.

    Returns:
        Path to the converted dataset file.

    Raises:
        FileNotFoundError: If the input file does not exist.
        ValueError: If the dataset is already in XYZW format.
    """
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input dataset file not found: {input_path}")
    # Generate output path if not provided
    if output_path is None:
        base, ext = os.path.splitext(input_path)
        output_path = f"{base}_xyzw{ext}"

    def convert_group_quaternions(src_group, dst_group):
        """Recursively copy and convert quaternions in groups."""
        # Copy attributes
        for attr_name, attr_value in src_group.attrs.items():
            dst_group.attrs[attr_name] = attr_value
        # Process items (items() resolves every child once)
        for key, src_item in src_group.items():
            if isinstance(src_item, h5py.Group):
                # Recursively handle groups
                convert_group_quaternions(src_item, dst_group.create_group(key))
                continue
            # Handle datasets
            data = np.array(src_item)
            # Convert root_pose quaternions (0-d datasets have no last axis to inspect)
            if key == "root_pose" and data.ndim > 0 and data.shape[-1] == 7:
                data = convert_pose_quat_wxyz_to_xyzw(data)
            # Preserve compression settings if possible
            dst_dataset = dst_group.create_dataset(key, data=data, compression=src_item.compression)
            # datasets can carry attributes as well; keep them
            for attr_name, attr_value in src_item.attrs.items():
                dst_dataset.attrs[attr_name] = attr_value

    with h5py.File(input_path, "r") as src_file:
        # Check if already converted (a missing attribute means legacy version 0)
        source_version = src_file.attrs.get("format_version", 0)
        if source_version >= DATASET_FORMAT_VERSION:
            raise ValueError(f"Dataset is already in XYZW format (version {source_version})")
        with h5py.File(output_path, "w") as dst_file:
            # Copy and convert all data
            convert_group_quaternions(src_file, dst_file)
            # Set the new format version last: copying the root attributes above would
            # otherwise overwrite it with an explicit legacy "format_version" of the source
            dst_file.attrs["format_version"] = DATASET_FORMAT_VERSION
    return output_path