# Source code for isaaclab_mimic.datagen.data_generator

# Copyright (c) 2024-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

"""
Base class for data generator.
"""
import asyncio
import numpy as np
import torch

from isaaclab_mimic.datagen.datagen_info import DatagenInfo
from isaaclab_mimic.datagen.selection_strategy import make_selection_strategy
from isaaclab_mimic.datagen.waypoint import WaypointSequence, WaypointTrajectory

import isaaclab.utils.math as PoseUtils
from isaaclab.envs.mimic_env_cfg import MimicEnvCfg

from .datagen_info_pool import DataGenInfoPool


[docs]class DataGenerator: """ The main data generator object that loads a source dataset, parses it, and generates new trajectories. """
[docs] def __init__( self, env, src_demo_datagen_info_pool=None, dataset_path=None, demo_keys=None, ): """ Args: env (Isaac Lab ManagerBasedEnv instance): environment to use for data generation src_demo_datagen_info_pool (DataGenInfoPool): source demo datagen info pool dataset_path (str): path to hdf5 dataset to use for generation demo_keys (list of str): list of demonstration keys to use in file. If not provided, all demonstration keys will be used. """ self.env = env self.env_cfg = env.cfg assert isinstance(self.env_cfg, MimicEnvCfg) self.dataset_path = dataset_path if len(self.env_cfg.subtask_configs) != 1: raise ValueError("Data generation currently supports only one end-effector.") (self.eef_name,) = self.env_cfg.subtask_configs.keys() (self.subtask_configs,) = self.env_cfg.subtask_configs.values() # sanity check on task spec offset ranges - final subtask should not have any offset randomization assert self.subtask_configs[-1].subtask_term_offset_range[0] == 0 assert self.subtask_configs[-1].subtask_term_offset_range[1] == 0 self.demo_keys = demo_keys if src_demo_datagen_info_pool is not None: self.src_demo_datagen_info_pool = src_demo_datagen_info_pool elif dataset_path is not None: self.src_demo_datagen_info_pool = DataGenInfoPool( env=self.env, env_cfg=self.env_cfg, device=self.env.device ) self.src_demo_datagen_info_pool.load_from_dataset_file(dataset_path, select_demo_keys=self.demo_keys) else: raise ValueError("Either src_demo_datagen_info_pool or dataset_path must be provided")
def __repr__(self): """ Pretty print this object. """ msg = str(self.__class__.__name__) msg += " (\n\tdataset_path={}\n\tdemo_keys={}\n)".format( self.dataset_path, self.demo_keys, ) return msg
[docs] def randomize_subtask_boundaries(self): """ Apply random offsets to sample subtask boundaries according to the task spec. Recall that each demonstration is segmented into a set of subtask segments, and the end index of each subtask can have a random offset. """ # initial subtask start and end indices - shape (N, S, 2) src_subtask_indices = np.array(self.src_demo_datagen_info_pool.subtask_indices) # for each subtask (except last one), sample all end offsets at once for each demonstration # add them to subtask end indices, and then set them as the start indices of next subtask too for i in range(src_subtask_indices.shape[1] - 1): end_offsets = np.random.randint( low=self.subtask_configs[i].subtask_term_offset_range[0], high=self.subtask_configs[i].subtask_term_offset_range[1] + 1, size=src_subtask_indices.shape[0], ) src_subtask_indices[:, i, 1] = src_subtask_indices[:, i, 1] + end_offsets # don't forget to set these as start indices for next subtask too src_subtask_indices[:, i + 1, 0] = src_subtask_indices[:, i, 1] # ensure non-empty subtasks assert np.all((src_subtask_indices[:, :, 1] - src_subtask_indices[:, :, 0]) > 0), "got empty subtasks!" # ensure subtask indices increase (both starts and ends) assert np.all( (src_subtask_indices[:, 1:, :] - src_subtask_indices[:, :-1, :]) > 0 ), "subtask indices do not strictly increase" # ensure subtasks are in order subtask_inds_flat = src_subtask_indices.reshape(src_subtask_indices.shape[0], -1) assert np.all((subtask_inds_flat[:, 1:] - subtask_inds_flat[:, :-1]) >= 0), "subtask indices not in order" return src_subtask_indices
[docs] def select_source_demo( self, eef_pose, object_pose, subtask_ind, src_subtask_inds, subtask_object_name, selection_strategy_name, selection_strategy_kwargs=None, ): """ Helper method to run source subtask segment selection. Args: eef_pose (np.array): current end effector pose object_pose (np.array): current object pose for this subtask subtask_ind (int): index of subtask src_subtask_inds (np.array): start and end indices for subtask segment in source demonstrations of shape (N, 2) subtask_object_name (str): name of reference object for this subtask selection_strategy_name (str): name of selection strategy selection_strategy_kwargs (dict): extra kwargs for running selection strategy Returns: selected_src_demo_ind (int): selected source demo index """ if subtask_object_name is None: # no reference object - only random selection is supported assert selection_strategy_name == "random" # We need to collect the datagen info objects over the timesteps for the subtask segment in each source # demo, so that it can be used by the selection strategy. 
src_subtask_datagen_infos = [] for i in range(len(self.src_demo_datagen_info_pool.datagen_infos)): # datagen info over all timesteps of the src trajectory src_ep_datagen_info = self.src_demo_datagen_info_pool.datagen_infos[i] # time indices for subtask subtask_start_ind = src_subtask_inds[i][0] subtask_end_ind = src_subtask_inds[i][1] # get subtask segment using indices src_subtask_datagen_infos.append( DatagenInfo( eef_pose=src_ep_datagen_info.eef_pose[subtask_start_ind:subtask_end_ind], # only include object pose for relevant object in subtask object_poses=( { subtask_object_name: src_ep_datagen_info.object_poses[subtask_object_name][ subtask_start_ind:subtask_end_ind ] } if (subtask_object_name is not None) else None ), # subtask termination signal is unused subtask_term_signals=None, target_eef_pose=src_ep_datagen_info.target_eef_pose[subtask_start_ind:subtask_end_ind], gripper_action=src_ep_datagen_info.gripper_action[subtask_start_ind:subtask_end_ind], ) ) # make selection strategy object selection_strategy_obj = make_selection_strategy(selection_strategy_name) # run selection if selection_strategy_kwargs is None: selection_strategy_kwargs = dict() selected_src_demo_ind = selection_strategy_obj.select_source_demo( eef_pose=eef_pose, object_pose=object_pose, src_subtask_datagen_infos=src_subtask_datagen_infos, **selection_strategy_kwargs, ) return selected_src_demo_ind
[docs] async def generate( self, env_id, success_term, env_action_queue: asyncio.Queue | None = None, select_src_per_subtask=False, transform_first_robot_pose=False, interpolate_from_last_target_pose=True, pause_subtask=False, export_demo=True, ): """ Attempt to generate a new demonstration. Args: env_id (int): environment ID success_term (TerminationTermCfg): success function to check if the task is successful env_action_queue (asyncio.Queue): queue to store actions for each environment select_src_per_subtask (bool): if True, select a different source demonstration for each subtask during data generation, else keep the same one for the entire episode transform_first_robot_pose (bool): if True, each subtask segment will consist of the first robot pose and the target poses instead of just the target poses. Can sometimes help improve data generation quality as the interpolation segment will interpolate to where the robot started in the source segment instead of the first target pose. Note that the first subtask segment of each episode will always include the first robot pose, regardless of this argument. interpolate_from_last_target_pose (bool): if True, each interpolation segment will start from the last target pose in the previous subtask segment, instead of the current robot pose. Can sometimes improve data generation quality. pause_subtask (bool): if True, pause after every subtask during generation, for debugging. 
Returns: results (dict): dictionary with the following items: initial_state (dict): initial simulator state for the executed trajectory states (list): simulator state at each timestep observations (list): observation dictionary at each timestep datagen_infos (list): datagen_info at each timestep actions (np.array): action executed at each timestep success (bool): whether the trajectory successfully solved the task or not src_demo_inds (list): list of selected source demonstration indices for each subtask src_demo_labels (np.array): same as @src_demo_inds, but repeated to have a label for each timestep of the trajectory """ eef_names = list(self.env_cfg.subtask_configs.keys()) eef_name = eef_names[0] # reset the env to create a new task demo instance env_id_tensor = torch.tensor([env_id], dtype=torch.int64, device=self.env.device) self.env.recorder_manager.reset(env_ids=env_id_tensor) self.env.reset(env_ids=env_id_tensor) new_initial_state = self.env.scene.get_state(is_relative=True) # some state variables used during generation selected_src_demo_ind = None prev_executed_traj = None # save generated data in these variables generated_states = [] generated_obs = [] generated_actions = [] generated_success = False generated_src_demo_inds = [] # store selected src demo ind for each subtask in each trajectory generated_src_demo_labels = ( [] ) # like @generated_src_demo_inds, but padded to align with size of @generated_actions prev_src_demo_datagen_info_pool_size = 0 for subtask_ind in range(len(self.subtask_configs)): # some things only happen on first subtask is_first_subtask = subtask_ind == 0 # name of object for this subtask subtask_object_name = self.subtask_configs[subtask_ind].object_ref # corresponding current object pose cur_object_pose = ( self.env.get_object_poses(env_ids=[env_id])[subtask_object_name][0] if (subtask_object_name is not None) else None ) async with self.src_demo_datagen_info_pool.asyncio_lock: if 
len(self.src_demo_datagen_info_pool.datagen_infos) > prev_src_demo_datagen_info_pool_size: # src_demo_datagen_info_pool at this point may be updated with new demos, # so we need to updaet subtask boundaries again all_subtask_inds = ( self.randomize_subtask_boundaries() ) # shape [N, S, 2], last dim is start and end action lengths prev_src_demo_datagen_info_pool_size = len(self.src_demo_datagen_info_pool.datagen_infos) # We need source demonstration selection for the first subtask (always), and possibly for # other subtasks if @select_src_per_subtask is set. need_source_demo_selection = is_first_subtask or select_src_per_subtask # Run source demo selection or use selected demo from previous iteration if need_source_demo_selection: selected_src_demo_ind = self.select_source_demo( eef_pose=self.env.get_robot_eef_pose(eef_name, env_ids=[env_id])[0], object_pose=cur_object_pose, subtask_ind=subtask_ind, src_subtask_inds=all_subtask_inds[:, subtask_ind], subtask_object_name=subtask_object_name, selection_strategy_name=self.subtask_configs[subtask_ind].selection_strategy, selection_strategy_kwargs=self.subtask_configs[subtask_ind].selection_strategy_kwargs, ) assert selected_src_demo_ind is not None # selected subtask segment time indices selected_src_subtask_inds = all_subtask_inds[selected_src_demo_ind, subtask_ind] # get subtask segment, consisting of the sequence of robot eef poses, target poses, gripper actions src_ep_datagen_info = self.src_demo_datagen_info_pool.datagen_infos[selected_src_demo_ind] src_subtask_eef_poses = src_ep_datagen_info.eef_pose[ selected_src_subtask_inds[0] : selected_src_subtask_inds[1] ] src_subtask_target_poses = src_ep_datagen_info.target_eef_pose[ selected_src_subtask_inds[0] : selected_src_subtask_inds[1] ] src_subtask_gripper_actions = src_ep_datagen_info.gripper_action[ selected_src_subtask_inds[0] : selected_src_subtask_inds[1] ] # get reference object pose from source demo src_subtask_object_pose = ( 
src_ep_datagen_info.object_poses[subtask_object_name][selected_src_subtask_inds[0]] if (subtask_object_name is not None) else None ) if is_first_subtask or transform_first_robot_pose: # Source segment consists of first robot eef pose and the target poses. src_eef_poses = torch.cat([src_subtask_eef_poses[0:1], src_subtask_target_poses], dim=0) else: # Source segment consists of just the target poses. src_eef_poses = src_subtask_target_poses.clone() # account for extra timestep added to @src_eef_poses src_subtask_gripper_actions = torch.cat( [src_subtask_gripper_actions[0:1], src_subtask_gripper_actions], dim=0 ) # Transform source demonstration segment using relevant object pose. if subtask_object_name is not None: transformed_eef_poses = PoseUtils.transform_poses_from_frame_A_to_frame_B( src_poses=src_eef_poses, frame_A=cur_object_pose, frame_B=src_subtask_object_pose, ) else: # skip transformation if no reference object is provided transformed_eef_poses = src_eef_poses # We will construct a WaypointTrajectory instance to keep track of robot control targets # that will be executed and then execute it. traj_to_execute = WaypointTrajectory() if interpolate_from_last_target_pose and (not is_first_subtask): # Interpolation segment will start from last target pose (which may not have been achieved). assert prev_executed_traj is not None last_waypoint = prev_executed_traj.last_waypoint init_sequence = WaypointSequence(sequence=[last_waypoint]) else: # Interpolation segment will start from current robot eef pose. init_sequence = WaypointSequence.from_poses( eef_names=eef_names, poses=self.env.get_robot_eef_pose(eef_name, env_ids=[env_id])[0][None], gripper_actions=src_subtask_gripper_actions[0:1], action_noise=self.subtask_configs[subtask_ind].action_noise, ) traj_to_execute.add_waypoint_sequence(init_sequence) # Construct trajectory for the transformed segment. 
transformed_seq = WaypointSequence.from_poses( eef_names=eef_names, poses=transformed_eef_poses, gripper_actions=src_subtask_gripper_actions, action_noise=self.subtask_configs[subtask_ind].action_noise, ) transformed_traj = WaypointTrajectory() transformed_traj.add_waypoint_sequence(transformed_seq) # Merge this trajectory into our trajectory using linear interpolation. # Interpolation will happen from the initial pose (@init_sequence) to the first element of @transformed_seq. traj_to_execute.merge( transformed_traj, eef_names=eef_names, num_steps_interp=self.subtask_configs[subtask_ind].num_interpolation_steps, num_steps_fixed=self.subtask_configs[subtask_ind].num_fixed_steps, action_noise=( float(self.subtask_configs[subtask_ind].apply_noise_during_interpolation) * self.subtask_configs[subtask_ind].action_noise ), ) # We initialized @traj_to_execute with a pose to allow @merge to handle linear interpolation # for us. However, we can safely discard that first waypoint now, and just start by executing # the rest of the trajectory (interpolation segment and transformed subtask segment). traj_to_execute.pop_first() # Execute the trajectory and collect data. exec_results = await traj_to_execute.execute( env=self.env, env_id=env_id, env_action_queue=env_action_queue, success_term=success_term ) # check that trajectory is non-empty if len(exec_results["states"]) > 0: generated_states += exec_results["states"] generated_obs += exec_results["observations"] generated_actions.append(exec_results["actions"]) generated_success = generated_success or exec_results["success"] generated_src_demo_inds.append(selected_src_demo_ind) generated_src_demo_labels.append( selected_src_demo_ind * torch.ones( (exec_results["actions"].shape[0], 1), dtype=torch.int, device=exec_results["actions"].device ) ) # remember last trajectory prev_executed_traj = traj_to_execute if pause_subtask: input(f"Pausing after subtask {subtask_ind} execution. 
Press any key to continue...") # merge numpy arrays if len(generated_actions) > 0: generated_actions = torch.cat(generated_actions, dim=0) generated_src_demo_labels = torch.cat(generated_src_demo_labels, dim=0) # set success to the recorded episode data and export to file self.env.recorder_manager.set_success_to_episodes( env_id_tensor, torch.tensor([[generated_success]], dtype=torch.bool, device=self.env.device) ) if export_demo: self.env.recorder_manager.export_episodes(env_id_tensor) results = dict( initial_state=new_initial_state, states=generated_states, observations=generated_obs, actions=generated_actions, success=generated_success, src_demo_inds=generated_src_demo_inds, src_demo_labels=generated_src_demo_labels, ) return results