# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

import re
from collections.abc import Sequence
from typing import TYPE_CHECKING

import numpy as np
import torch

import omni.log
import omni.physics.tensors.impl.api as physx
import omni.usd
import warp as wp
from omni.isaac.core.prims import XFormPrimView
from pxr import UsdGeom, UsdPhysics

import omni.isaac.lab.sim as sim_utils
from omni.isaac.lab.markers import VisualizationMarkers
from omni.isaac.lab.terrains.trimesh.utils import make_plane
from omni.isaac.lab.utils.math import convert_quat, quat_apply, quat_apply_yaw
from omni.isaac.lab.utils.warp import convert_to_warp_mesh, raycast_mesh

from ..sensor_base import SensorBase
from .ray_caster_data import RayCasterData

if TYPE_CHECKING:
    from .ray_caster_cfg import RayCasterCfg
class RayCaster(SensorBase):
    """A ray-casting sensor.

    The ray-caster uses a set of rays to detect collisions with meshes in the scene. The rays are
    defined in the sensor's local coordinate frame. The sensor can be configured to ray-cast against
    a set of meshes with a given ray pattern.

    The meshes are parsed from the list of primitive paths provided in the configuration. These are then
    converted to warp meshes and stored in the :attr:`meshes` dictionary (keyed by the configured mesh
    prim path). The ray-caster then ray-casts against these warp meshes using the ray pattern provided
    in the configuration.

    .. note::
        Currently, only static meshes are supported. Extending the warp mesh to support dynamic meshes
        is a work in progress.
    """

    cfg: RayCasterCfg
    """The configuration parameters."""

    def __init__(self, cfg: RayCasterCfg):
        """Initializes the ray-caster object.

        Args:
            cfg: The configuration parameters.

        Raises:
            RuntimeError: If the leaf of the prim path contains characters other than
                letters, digits, ``/`` and ``_`` (i.e. it looks like a regex pattern).
        """
        # check if sensor path is valid
        # note: currently we do not handle environment indices if there is a regex pattern in the leaf
        #   For example, if the prim path is "/World/Sensor_[1,2]".
        sensor_path = cfg.prim_path.split("/")[-1]
        sensor_path_is_regex = re.match(r"^[a-zA-Z0-9/_]+$", sensor_path) is None
        if sensor_path_is_regex:
            # note: use the ``cfg`` argument here. ``self.cfg`` is not assigned in this class before the
            #   base-class initializer runs, so reading it would raise an AttributeError instead.
            raise RuntimeError(
                f"Invalid prim path for the ray-caster sensor: {cfg.prim_path}."
                "\n\tHint: Please ensure that the prim path does not contain any regex patterns in the leaf."
            )
        # Initialize base class
        super().__init__(cfg)
        # Create empty variables for storing output data
        self._data = RayCasterData()
        # the warp meshes used for raycasting.
        self.meshes: dict[str, wp.Mesh] = {}

    def __str__(self) -> str:
        """Returns: A string containing information about the instance."""
        return (
            f"Ray-caster @ '{self.cfg.prim_path}': \n"
            f"\tview type : {self._view.__class__}\n"
            f"\tupdate period (s) : {self.cfg.update_period}\n"
            f"\tnumber of meshes : {len(self.meshes)}\n"
            f"\tnumber of sensors : {self._view.count}\n"
            f"\tnumber of rays/sensor: {self.num_rays}\n"
            f"\ttotal number of rays : {self.num_rays * self._view.count}"
        )

    """
    Properties
    """

    @property
    def num_instances(self) -> int:
        """Number of sensor instances, i.e. the number of prims captured by the view."""
        return self._view.count

    @property
    def data(self) -> RayCasterData:
        """The sensor data. Outdated buffers are refreshed before the data is returned."""
        # update sensors if needed
        self._update_outdated_buffers()
        # return the data
        return self._data

    """
    Operations.
    """

    def reset(self, env_ids: Sequence[int] | None = None):
        """Resets the sensor and resamples the position drift.

        Args:
            env_ids: The sensor indices to reset. Defaults to None, in which case all are reset.
        """
        # reset the timers and counters
        super().reset(env_ids)
        # resolve None
        if env_ids is None:
            env_ids = slice(None)
        # resample the drift
        # note: indexing with a list/tensor of indices returns a copy, so an in-place ``uniform_`` on
        #   ``self.drift[env_ids]`` would be silently discarded. Sample into a scratch tensor of the
        #   right shape and assign it back instead; this works for slices and index sequences alike.
        self.drift[env_ids] = torch.empty_like(self.drift[env_ids]).uniform_(*self.cfg.drift_range)

    """
    Implementation.
    """

    def _initialize_impl(self):
        """Creates the prim view for the sensor, loads the meshes and sets up the rays.

        Raises:
            RuntimeError: If no prim matches the configured prim path expression.
        """
        super()._initialize_impl()
        # create simulation view
        self._physics_sim_view = physx.create_simulation_view(self._backend)
        self._physics_sim_view.set_subspace_roots("/")
        # check if the prim at path is an articulated or rigid prim
        # we do this since for physics-based view classes we can access their data directly
        # otherwise we need to use the xform view class which is slower
        prim = sim_utils.find_first_matching_prim(self.cfg.prim_path)
        if prim is None:
            raise RuntimeError(f"Failed to find a prim at path expression: {self.cfg.prim_path}")
        # create view based on the type of prim
        # note: every branch yields a view, so no separate "unsupported prim" check is needed.
        if prim.HasAPI(UsdPhysics.ArticulationRootAPI):
            self._view = self._physics_sim_view.create_articulation_view(self.cfg.prim_path.replace(".*", "*"))
        elif prim.HasAPI(UsdPhysics.RigidBodyAPI):
            self._view = self._physics_sim_view.create_rigid_body_view(self.cfg.prim_path.replace(".*", "*"))
        else:
            self._view = XFormPrimView(self.cfg.prim_path, reset_xform_properties=False)
            omni.log.warn(f"The prim at path {prim.GetPath().pathString} is not a physics prim! Using XFormPrimView.")

        # load the meshes by parsing the stage
        self._initialize_warp_meshes()
        # initialize the ray start and directions
        self._initialize_rays_impl()

    def _initialize_warp_meshes(self):
        """Parses the configured mesh prim and stores it as a warp mesh in :attr:`meshes`.

        Raises:
            NotImplementedError: If the number of configured mesh prim paths is not exactly one.
            RuntimeError: If neither a ``Plane`` nor a valid ``Mesh`` prim is found under a mesh prim path.
        """
        # check number of mesh prims provided
        if len(self.cfg.mesh_prim_paths) != 1:
            raise NotImplementedError(
                f"RayCaster currently only supports one mesh prim. Received: {len(self.cfg.mesh_prim_paths)}"
            )

        # read prims to ray-cast
        for mesh_prim_path in self.cfg.mesh_prim_paths:
            # check if the prim is a plane - handle PhysX plane as a special case
            # if a plane exists then we need to create an infinite mesh that is a plane
            mesh_prim = sim_utils.get_first_matching_child_prim(
                mesh_prim_path, lambda prim: prim.GetTypeName() == "Plane"
            )
            # if we did not find a plane then we need to read the mesh
            if mesh_prim is None:
                # obtain the mesh prim
                mesh_prim = sim_utils.get_first_matching_child_prim(
                    mesh_prim_path, lambda prim: prim.GetTypeName() == "Mesh"
                )
                # check if valid
                if mesh_prim is None or not mesh_prim.IsValid():
                    raise RuntimeError(f"Invalid mesh prim path: {mesh_prim_path}")
                # cast into UsdGeomMesh
                mesh_prim = UsdGeom.Mesh(mesh_prim)
                # read the vertices and faces
                points = np.asarray(mesh_prim.GetPointsAttr().Get())
                # bring the vertices into the world frame
                # note: after the transpose, the upper-left 3x3 block is used as the linear part and
                #   the last column as the translation.
                transform_matrix = np.array(omni.usd.get_world_transform_matrix(mesh_prim)).T
                points = np.matmul(points, transform_matrix[:3, :3].T)
                points += transform_matrix[:3, 3]
                indices = np.asarray(mesh_prim.GetFaceVertexIndicesAttr().Get())
                wp_mesh = convert_to_warp_mesh(points, indices, device=self.device)
                # print info
                # NOTE(review): ``len(indices)`` counts face-vertex indices, not faces - confirm whether
                #   the logged "faces" number should be divided by the vertices per face.
                omni.log.info(
                    f"Read mesh prim: {mesh_prim.GetPath()} with {len(points)} vertices and {len(indices)} faces."
                )
            else:
                # a very large flat plane stands in for the infinite ground plane
                mesh = make_plane(size=(2e6, 2e6), height=0.0, center_zero=True)
                wp_mesh = convert_to_warp_mesh(mesh.vertices, mesh.faces, device=self.device)
                # print info
                omni.log.info(f"Created infinite plane mesh prim: {mesh_prim.GetPath()}.")
            # add the warp mesh to the dictionary
            self.meshes[mesh_prim_path] = wp_mesh

        # throw an error if no meshes are found
        if not any(mesh_prim_path in self.meshes for mesh_prim_path in self.cfg.mesh_prim_paths):
            raise RuntimeError(
                f"No meshes found for ray-casting! Please check the mesh prim paths: {self.cfg.mesh_prim_paths}"
            )

    def _initialize_rays_impl(self):
        """Computes the local ray pattern, applies the sensor offset and allocates the data buffers."""
        # compute ray starts and directions
        self.ray_starts, self.ray_directions = self.cfg.pattern_cfg.func(self.cfg.pattern_cfg, self._device)
        self.num_rays = len(self.ray_directions)
        # apply offset transformation to the rays
        offset_pos = torch.tensor(list(self.cfg.offset.pos), device=self._device)
        offset_quat = torch.tensor(list(self.cfg.offset.rot), device=self._device)
        self.ray_directions = quat_apply(offset_quat.repeat(len(self.ray_directions), 1), self.ray_directions)
        self.ray_starts += offset_pos
        # repeat the rays for each sensor
        self.ray_starts = self.ray_starts.repeat(self._view.count, 1, 1)
        self.ray_directions = self.ray_directions.repeat(self._view.count, 1, 1)
        # prepare drift
        self.drift = torch.zeros(self._view.count, 3, device=self.device)
        # fill the data buffer
        self._data.pos_w = torch.zeros(self._view.count, 3, device=self._device)
        self._data.quat_w = torch.zeros(self._view.count, 4, device=self._device)
        self._data.ray_hits_w = torch.zeros(self._view.count, self.num_rays, 3, device=self._device)

    def _update_buffers_impl(self, env_ids: Sequence[int]):
        """Fills the buffers of the sensor data.

        Args:
            env_ids: The sensor indices whose poses and ray hits are recomputed.

        Raises:
            RuntimeError: If the view is not an XForm, articulation or rigid-body view.
        """
        # obtain the poses of the sensors
        if isinstance(self._view, XFormPrimView):
            pos_w, quat_w = self._view.get_world_poses(env_ids)
        elif isinstance(self._view, physx.ArticulationView):
            pos_w, quat_w = self._view.get_root_transforms()[env_ids].split([3, 4], dim=-1)
            quat_w = convert_quat(quat_w, to="wxyz")
        elif isinstance(self._view, physx.RigidBodyView):
            pos_w, quat_w = self._view.get_transforms()[env_ids].split([3, 4], dim=-1)
            quat_w = convert_quat(quat_w, to="wxyz")
        else:
            raise RuntimeError(f"Unsupported view type: {type(self._view)}")
        # note: clone before the in-place drift update below so that the tensors
        #   returned by the view are not modified
        pos_w = pos_w.clone()
        quat_w = quat_w.clone()
        # apply drift
        pos_w += self.drift[env_ids]
        # store the poses
        self._data.pos_w[env_ids] = pos_w
        self._data.quat_w[env_ids] = quat_w

        # ray cast based on the sensor poses
        if self.cfg.attach_yaw_only:
            # only yaw orientation is considered and directions are not rotated
            ray_starts_w = quat_apply_yaw(quat_w.repeat(1, self.num_rays), self.ray_starts[env_ids])
            ray_starts_w += pos_w.unsqueeze(1)
            ray_directions_w = self.ray_directions[env_ids]
        else:
            # full orientation is considered
            ray_starts_w = quat_apply(quat_w.repeat(1, self.num_rays), self.ray_starts[env_ids])
            ray_starts_w += pos_w.unsqueeze(1)
            ray_directions_w = quat_apply(quat_w.repeat(1, self.num_rays), self.ray_directions[env_ids])
        # ray cast and store the hits
        # TODO: Make this work for multiple meshes?
        self._data.ray_hits_w[env_ids] = raycast_mesh(
            ray_starts_w,
            ray_directions_w,
            max_dist=self.cfg.max_distance,
            mesh=self.meshes[self.cfg.mesh_prim_paths[0]],
        )[0]

    def _set_debug_vis_impl(self, debug_vis: bool):
        """Shows or hides the ray-hit markers, creating them on first use.

        Args:
            debug_vis: Whether the markers should be visible.
        """
        # set visibility of markers
        # note: parent only deals with callbacks. not their visibility
        if debug_vis:
            if not hasattr(self, "ray_visualizer"):
                self.ray_visualizer = VisualizationMarkers(self.cfg.visualizer_cfg)
            # set their visibility to true
            self.ray_visualizer.set_visibility(True)
        else:
            if hasattr(self, "ray_visualizer"):
                self.ray_visualizer.set_visibility(False)

    def _debug_vis_callback(self, event):
        """Draws a marker at every stored ray hit position."""
        # show ray hit positions
        self.ray_visualizer.visualize(self._data.ray_hits_w.view(-1, 3))

    """
    Internal simulation callbacks.
    """

    def _invalidate_initialize_callback(self, event):
        """Invalidates the scene elements."""
        # call parent
        super()._invalidate_initialize_callback(event)
        # set all existing views to None to invalidate them
        self._physics_sim_view = None
        self._view = None