# Copyright (c) 2022-2025, The Isaac Lab Project Developers.# All rights reserved.## SPDX-License-Identifier: BSD-3-Clauseimportcopyimportosimporttorch
def export_policy_as_jit(policy: object, normalizer: object | None, path: str, filename="policy.pt"):
    """Save a policy (plus optional observation normalizer) as a TorchScript file.

    The policy is wrapped in the private ``_TorchPolicyExporter`` module, which
    deep-copies the networks, scripts itself with ``torch.jit.script`` and writes
    the result to ``os.path.join(path, filename)`` (creating ``path`` if needed).

    Args:
        policy: The policy torch module. It must expose ``is_recurrent`` and either
            an ``actor`` or a ``student`` submodule.
        normalizer: The empirical normalizer module. If None, Identity is used.
        path: The path to the saving directory.
        filename: The name of exported JIT file. Defaults to "policy.pt".
    """
    _TorchPolicyExporter(policy, normalizer).export(path, filename)


def export_policy_as_onnx(
    policy: object, path: str, normalizer: object | None = None, filename="policy.onnx", verbose=False
):
    """Export policy into an ONNX file.

    Args:
        policy: The policy torch module. It must expose ``is_recurrent`` and either
            an ``actor`` or a ``student`` submodule.
        path: The path to the saving directory. Created if it does not exist.
        normalizer: The empirical normalizer module. If None, Identity is used.
        filename: The name of exported ONNX file. Defaults to "policy.onnx".
        verbose: Whether to print the model summary. Defaults to False.
    """
    # exist_ok=True already tolerates an existing directory, so a separate
    # os.path.exists() pre-check is redundant (and racy between check and create).
    os.makedirs(path, exist_ok=True)
    _OnnxPolicyExporter(policy, normalizer, verbose).export(path, filename)


"""Helper Classes - Private."""class_TorchPolicyExporter(torch.nn.Module):"""Exporter of actor-critic into JIT file."""def__init__(self,policy,normalizer=None):super().__init__()self.is_recurrent=policy.is_recurrent# copy policy parametersifhasattr(policy,"actor"):self.actor=copy.deepcopy(policy.actor)ifself.is_recurrent:self.rnn=copy.deepcopy(policy.memory_a.rnn)elifhasattr(policy,"student"):self.actor=copy.deepcopy(policy.student)ifself.is_recurrent:self.rnn=copy.deepcopy(policy.memory_s.rnn)else:raiseValueError("Policy does not have an actor/student module.")# set up recurrent networkifself.is_recurrent:self.rnn.cpu()self.register_buffer("hidden_state",torch.zeros(self.rnn.num_layers,1,self.rnn.hidden_size))self.register_buffer("cell_state",torch.zeros(self.rnn.num_layers,1,self.rnn.hidden_size))self.forward=self.forward_lstmself.reset=self.reset_memory# copy normalizer if existsifnormalizer:self.normalizer=copy.deepcopy(normalizer)else:self.normalizer=torch.nn.Identity()defforward_lstm(self,x):x=self.normalizer(x)x,(h,c)=self.rnn(x.unsqueeze(0),(self.hidden_state,self.cell_state))self.hidden_state[:]=hself.cell_state[:]=cx=x.squeeze(0)returnself.actor(x)defforward(self,x):returnself.actor(self.normalizer(x))@torch.jit.exportdefreset(self):passdefreset_memory(self):self.hidden_state[:]=0.0self.cell_state[:]=0.0defexport(self,path,filename):os.makedirs(path,exist_ok=True)path=os.path.join(path,filename)self.to("cpu")traced_script_module=torch.jit.script(self)traced_script_module.save(path)class_OnnxPolicyExporter(torch.nn.Module):"""Exporter of actor-critic into ONNX file."""def__init__(self,policy,normalizer=None,verbose=False):super().__init__()self.verbose=verboseself.is_recurrent=policy.is_recurrent# copy policy 
parametersifhasattr(policy,"actor"):self.actor=copy.deepcopy(policy.actor)ifself.is_recurrent:self.rnn=copy.deepcopy(policy.memory_a.rnn)elifhasattr(policy,"student"):self.actor=copy.deepcopy(policy.student)ifself.is_recurrent:self.rnn=copy.deepcopy(policy.memory_s.rnn)else:raiseValueError("Policy does not have an actor/student module.")# set up recurrent networkifself.is_recurrent:self.rnn.cpu()self.forward=self.forward_lstm# copy normalizer if existsifnormalizer:self.normalizer=copy.deepcopy(normalizer)else:self.normalizer=torch.nn.Identity()defforward_lstm(self,x_in,h_in,c_in):x_in=self.normalizer(x_in)x,(h,c)=self.rnn(x_in.unsqueeze(0),(h_in,c_in))x=x.squeeze(0)returnself.actor(x),h,cdefforward(self,x):returnself.actor(self.normalizer(x))defexport(self,path,filename):self.to("cpu")ifself.is_recurrent:obs=torch.zeros(1,self.rnn.input_size)h_in=torch.zeros(self.rnn.num_layers,1,self.rnn.hidden_size)c_in=torch.zeros(self.rnn.num_layers,1,self.rnn.hidden_size)actions,h_out,c_out=self(obs,h_in,c_in)torch.onnx.export(self,(obs,h_in,c_in),os.path.join(path,filename),export_params=True,opset_version=11,verbose=self.verbose,input_names=["obs","h_in","c_in"],output_names=["actions","h_out","c_out"],dynamic_axes={},)else:obs=torch.zeros(1,self.actor[0].in_features)torch.onnx.export(self,obs,os.path.join(path,filename),export_params=True,opset_version=11,verbose=self.verbose,input_names=["obs"],output_names=["actions"],dynamic_axes={},)