# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import copy
import os
import torch
def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
    """Save the policy as a TorchScript (JIT) archive.

    Args:
        actor_critic: The actor-critic torch module.
        normalizer: The empirical normalizer module. If None, Identity is used.
        path: The path to the saving directory.
        filename: The name of exported JIT file. Defaults to "policy.pt".
    """
    # The exporter wraps copies of the modules and handles directory creation itself.
    _TorchPolicyExporter(actor_critic, normalizer).export(path, filename)
def export_policy_as_onnx(
    actor_critic: object, path: str, normalizer: object | None = None, filename="policy.onnx", verbose=False
):
    """Export policy into an ONNX file.

    Args:
        actor_critic: The actor-critic torch module.
        path: The path to the saving directory. It is created if it does not exist.
        normalizer: The empirical normalizer module. If None, Identity is used.
        filename: The name of exported ONNX file. Defaults to "policy.onnx".
        verbose: Whether to print the model summary. Defaults to False.
    """
    # ``exist_ok=True`` already tolerates an existing directory, so a separate
    # (and racy) ``os.path.exists`` check is unnecessary.
    os.makedirs(path, exist_ok=True)
    policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
    policy_exporter.export(path, filename)
"""
Helper Classes - Private.
"""
class _TorchPolicyExporter(torch.nn.Module):
"""Exporter of actor-critic into JIT file."""
def __init__(self, actor_critic, normalizer=None):
super().__init__()
self.actor = copy.deepcopy(actor_critic.actor)
self.is_recurrent = actor_critic.is_recurrent
if self.is_recurrent:
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
self.rnn.cpu()
self.register_buffer("hidden_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.forward = self.forward_lstm
self.reset = self.reset_memory
# copy normalizer if exists
if normalizer:
self.normalizer = copy.deepcopy(normalizer)
else:
self.normalizer = torch.nn.Identity()
def forward_lstm(self, x):
x = self.normalizer(x)
x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state))
self.hidden_state[:] = h
self.cell_state[:] = c
x = x.squeeze(0)
return self.actor(x)
def forward(self, x):
return self.actor(self.normalizer(x))
@torch.jit.export
def reset(self):
pass
def reset_memory(self):
self.hidden_state[:] = 0.0
self.cell_state[:] = 0.0
def export(self, path, filename):
os.makedirs(path, exist_ok=True)
path = os.path.join(path, filename)
self.to("cpu")
traced_script_module = torch.jit.script(self)
traced_script_module.save(path)
class _OnnxPolicyExporter(torch.nn.Module):
    """Exporter of actor-critic into ONNX file.

    The actor, the normalizer and (for recurrent policies) the actor's RNN are
    deep-copied, so exporting does not modify the modules passed in. The recurrent
    graph is stateless: hidden and cell states are explicit inputs and outputs.
    """

    def __init__(self, actor_critic, normalizer=None, verbose=False):
        super().__init__()
        self.verbose = verbose
        self.actor = copy.deepcopy(actor_critic.actor)
        self.is_recurrent = actor_critic.is_recurrent
        if self.is_recurrent:
            self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
            self.rnn.cpu()
            # Swap in the recurrent forward pass (takes and returns h/c explicitly).
            self.forward = self.forward_lstm
        # Copy the normalizer if one was given. Compare against None explicitly:
        # modules that define ``__len__`` (e.g. an empty ``torch.nn.Sequential``)
        # are falsy and would otherwise be silently replaced by Identity.
        if normalizer is not None:
            self.normalizer = copy.deepcopy(normalizer)
        else:
            self.normalizer = torch.nn.Identity()

    def forward_lstm(self, x_in, h_in, c_in):
        """Run one RNN step from ``(h_in, c_in)``; return ``(actions, h_out, c_out)``."""
        x_in = self.normalizer(x_in)
        # unsqueeze(0) adds a length-1 sequence dimension for the RNN.
        x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in))
        x = x.squeeze(0)
        return self.actor(x), h, c

    def forward(self, x):
        """Feed-forward pass: normalize the observation and run the actor."""
        return self.actor(self.normalizer(x))

    def export(self, path, filename):
        """Export this module to ``path/filename`` as ONNX, creating ``path`` if needed."""
        self.to("cpu")
        # Exported policies are meant for inference: disable any training-mode-only
        # behavior of the copied modules.
        self.eval()
        os.makedirs(path, exist_ok=True)
        file_path = os.path.join(path, filename)
        if self.is_recurrent:
            # Dummy inputs for a batch of one: obs (1, input_size), h/c (num_layers, 1, hidden_size).
            obs = torch.zeros(1, self.rnn.input_size)
            h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
            c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
            example_args = (obs, h_in, c_in)
            input_names = ["obs", "h_in", "c_in"]
            output_names = ["actions", "h_out", "c_out"]
        else:
            # NOTE(review): assumes the actor is indexable and its first layer exposes
            # ``in_features`` (e.g. a Sequential starting with Linear) -- confirm.
            example_args = torch.zeros(1, self.actor[0].in_features)
            input_names = ["obs"]
            output_names = ["actions"]
        torch.onnx.export(
            self,
            example_args,
            file_path,
            export_params=True,
            opset_version=11,
            verbose=self.verbose,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes={},
        )