# Source code for isaaclab.sensors.ray_caster.multi_mesh_ray_caster_camera

# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
from collections.abc import Sequence
from typing import TYPE_CHECKING

import isaaclab.utils.math as math_utils
from isaaclab.utils.warp import raycast_dynamic_meshes

from .multi_mesh_ray_caster import MultiMeshRayCaster
from .multi_mesh_ray_caster_camera_data import MultiMeshRayCasterCameraData
from .prim_utils import obtain_world_pose_from_view
from .ray_caster_camera import RayCasterCamera

if TYPE_CHECKING:
    from .multi_mesh_ray_caster_camera_cfg import MultiMeshRayCasterCameraCfg


class MultiMeshRayCasterCamera(RayCasterCamera, MultiMeshRayCaster):
    """A multi-mesh ray-casting camera sensor.

    The ray-caster camera uses a set of rays to get the distances to meshes in the scene. The rays are
    defined in the sensor's local coordinate frame. The sensor has the same interface as the
    :class:`isaaclab.sensors.Camera` that implements the camera class through USD camera prims.
    However, this class provides a faster image generation. The sensor converts meshes from the list of
    primitive paths provided in the configuration to Warp meshes. The camera then ray-casts against these
    Warp meshes only.

    Currently, only the following annotators are supported:

    - ``"distance_to_camera"``: An image containing the distance to camera optical center.
    - ``"distance_to_image_plane"``: An image containing distances of 3D points from camera plane along
      camera's z-axis.
    - ``"normals"``: An image containing the local surface normal vectors at each pixel.
    """

    cfg: MultiMeshRayCasterCameraCfg
    """The configuration parameters."""

    def __init__(self, cfg: MultiMeshRayCasterCameraCfg):
        """Initializes the camera object.

        Args:
            cfg: The configuration parameters.

        Raises:
            ValueError: If the provided data types are not supported by the ray-caster camera.
        """
        self._check_supported_data_types(cfg)
        # initialize base class
        # note: the multi-mesh ray-caster constructor is called explicitly rather than through ``super()``
        MultiMeshRayCaster.__init__(self, cfg)
        # create empty variables for storing output data
        self._data = MultiMeshRayCasterCameraData()

    def __str__(self) -> str:
        """Returns: A string containing information about the instance."""
        return (
            f"Multi-Mesh Ray-Caster-Camera @ '{self.cfg.prim_path}': \n"
            f"\tview type : {self._view.__class__}\n"
            f"\tupdate period (s) : {self.cfg.update_period}\n"
            f"\tnumber of meshes : {len(MultiMeshRayCaster.meshes)}\n"
            f"\tnumber of sensors : {self._view.count}\n"
            f"\tnumber of rays/sensor: {self.num_rays}\n"
            f"\ttotal number of rays : {self.num_rays * self._view.count}\n"
            f"\timage shape : {self.image_shape}"
        )

    """
    Implementation.
    """

    def _initialize_warp_meshes(self):
        """Build the Warp meshes using the multi-mesh ray-caster implementation."""
        # Dispatch explicitly so the choice of implementation does not depend on the MRO.
        MultiMeshRayCaster._initialize_warp_meshes(self)

    def _create_buffers(self):
        """Create the data buffers, including the per-pixel mesh-id image."""
        super()._create_buffers()
        # id of the mesh hit by each pixel's ray; only written when ``cfg.update_mesh_ids`` is set
        self._data.image_mesh_ids = torch.zeros(
            self._num_envs, *self.image_shape, 1, device=self.device, dtype=torch.int16
        )

    def _initialize_rays_impl(self):
        """Initialize index/frame buffers, data buffers, the ray pattern and the sensor offset."""
        # Create all indices buffer
        self._ALL_INDICES = torch.arange(self._view.count, device=self._device, dtype=torch.long)
        # Create frame count buffer
        self._frame = torch.zeros(self._view.count, device=self._device, dtype=torch.long)
        # create buffers
        self._create_buffers()
        # compute intrinsic matrices
        self._compute_intrinsic_matrices()
        # compute ray starts and directions (sensor frame) from the pattern and the intrinsics
        self.ray_starts, self.ray_directions = self.cfg.pattern_cfg.func(
            self.cfg.pattern_cfg, self._data.intrinsic_matrices, self._device
        )
        self.num_rays = self.ray_directions.shape[1]
        # create buffer to store ray hits
        self.ray_hits_w = torch.zeros(self._view.count, self.num_rays, 3, device=self._device)
        # set offsets: the configured rotation is converted from its own convention to "world"
        quat_w = math_utils.convert_camera_frame_orientation_convention(
            torch.tensor([self.cfg.offset.rot], device=self._device), origin=self.cfg.offset.convention, target="world"
        )
        self._offset_quat = quat_w.repeat(self._view.count, 1)
        self._offset_pos = torch.tensor(list(self.cfg.offset.pos), device=self._device).repeat(self._view.count, 1)
        # sensor pose buffers, refreshed in ``_update_ray_infos``
        self._data.quat_w = torch.zeros(self._view.count, 4, device=self.device)
        self._data.pos_w = torch.zeros(self._view.count, 3, device=self.device)
        # world-frame ray buffers, refreshed in ``_update_ray_infos``
        self._ray_starts_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device)
        self._ray_directions_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device)

    def _update_ray_infos(self, env_ids: Sequence[int] | torch.Tensor):
        """Update the sensor poses and the world-frame ray buffers for the given sensors.

        Args:
            env_ids: Indices of the sensors to update. Must be explicit indices (not ``None``): the
                offset and ray buffers are indexed with them directly.
        """
        # compute poses from current view
        pos_w, quat_w = obtain_world_pose_from_view(self._view, env_ids)
        pos_w, quat_w = math_utils.combine_frame_transforms(
            pos_w, quat_w, self._offset_pos[env_ids], self._offset_quat[env_ids]
        )
        # update the data
        self._data.pos_w[env_ids] = pos_w
        self._data.quat_w_world[env_ids] = quat_w
        # keep the ray-caster orientation buffer in sync with ``pos_w``.
        # note: ``quat_w`` is in the "world" convention, so it must not be written into a ROS-convention buffer.
        self._data.quat_w[env_ids] = quat_w

        # note: full orientation is considered
        quat_per_ray = quat_w.repeat(1, self.num_rays)
        ray_starts_w = math_utils.quat_apply(quat_per_ray, self.ray_starts[env_ids])
        ray_starts_w += pos_w.unsqueeze(1)
        ray_directions_w = math_utils.quat_apply(quat_per_ray, self.ray_directions[env_ids])
        self._ray_starts_w[env_ids] = ray_starts_w
        self._ray_directions_w[env_ids] = ray_directions_w

    def _refresh_tracked_mesh_poses(self):
        """Copy the current world poses of tracked target meshes into the mesh pose buffers."""
        mesh_idx = 0
        for view, target_cfg in zip(self._mesh_views, self._raycast_targets_cfg):
            if not target_cfg.track_mesh_transforms:
                # static target: skip over its slots in the pose buffers
                mesh_idx += self._num_meshes_per_env[target_cfg.prim_expr]
                continue
            # update position of the target meshes
            pos_w, ori_w = obtain_world_pose_from_view(view, None)
            pos_w = pos_w.squeeze(0) if pos_w.dim() == 3 else pos_w
            ori_w = ori_w.squeeze(0) if ori_w.dim() == 3 else ori_w

            if target_cfg.prim_expr in MultiMeshRayCaster.mesh_offsets:
                pos_offset, ori_offset = MultiMeshRayCaster.mesh_offsets[target_cfg.prim_expr]
                pos_w -= pos_offset
                ori_w = math_utils.quat_mul(ori_offset.expand(ori_w.shape[0], -1), ori_w)

            count = view.count
            if count != 1:  # Mesh is not global, i.e. we have different meshes for each env
                count = count // self._num_envs
                pos_w = pos_w.view(self._num_envs, count, 3)
                ori_w = ori_w.view(self._num_envs, count, 4)

            self._mesh_positions_w[:, mesh_idx : mesh_idx + count] = pos_w
            self._mesh_orientations_w[:, mesh_idx : mesh_idx + count] = ori_w
            mesh_idx += count

    def _update_buffers_impl(self, env_ids: Sequence[int] | torch.Tensor | None):
        """Fills the buffers of the sensor data.

        Args:
            env_ids: Indices of the sensors to update, or ``None`` to update all of them.
        """
        # Resolve the indices *before* using them: indexing a tensor with ``None`` inserts a new axis
        # instead of selecting every row, and ``torch.tensor([])`` would default to a float dtype.
        if env_ids is None:
            env_ids = torch.arange(self._num_envs, device=self.device)
        elif not isinstance(env_ids, torch.Tensor):
            env_ids = torch.tensor(env_ids, device=self.device, dtype=torch.long)

        self._update_ray_infos(env_ids)
        # increment frame count
        self._frame[env_ids] += 1

        # Update the mesh positions and rotations
        self._refresh_tracked_mesh_poses()

        data_types = self.cfg.data_types
        # ray cast and store the hits
        self.ray_hits_w[env_ids], ray_depth, ray_normal, _, ray_mesh_ids = raycast_dynamic_meshes(
            self._ray_starts_w[env_ids],
            self._ray_directions_w[env_ids],
            mesh_ids_wp=self._mesh_ids_wp,  # list with shape num_envs x num_meshes_per_env
            max_dist=self.cfg.max_distance,
            mesh_positions_w=self._mesh_positions_w[env_ids],
            mesh_orientations_w=self._mesh_orientations_w[env_ids],
            return_distance=any(name in data_types for name in ("distance_to_image_plane", "distance_to_camera")),
            return_normal="normals" in data_types,
            return_mesh_id=self.cfg.update_mesh_ids,
        )

        # update output buffers
        if "distance_to_image_plane" in data_types:
            # note: data is in camera frame so we only take the first component (z-axis of camera frame)
            distance_to_image_plane = (
                math_utils.quat_apply(
                    math_utils.quat_inv(self._data.quat_w_world[env_ids]).repeat(1, self.num_rays),
                    (ray_depth[:, :, None] * self._ray_directions_w[env_ids]),
                )
            )[:, :, 0]
            # apply the maximum distance after the transformation
            if self.cfg.depth_clipping_behavior == "max":
                distance_to_image_plane = torch.clip(distance_to_image_plane, max=self.cfg.max_distance)
                distance_to_image_plane[torch.isnan(distance_to_image_plane)] = self.cfg.max_distance
            elif self.cfg.depth_clipping_behavior == "zero":
                distance_to_image_plane[distance_to_image_plane > self.cfg.max_distance] = 0.0
                distance_to_image_plane[torch.isnan(distance_to_image_plane)] = 0.0
            self._data.output["distance_to_image_plane"][env_ids] = distance_to_image_plane.view(
                -1, *self.image_shape, 1
            )

        if "distance_to_camera" in data_types:
            if self.cfg.depth_clipping_behavior == "max":
                ray_depth = torch.clip(ray_depth, max=self.cfg.max_distance)
            elif self.cfg.depth_clipping_behavior == "zero":
                ray_depth[ray_depth > self.cfg.max_distance] = 0.0
            self._data.output["distance_to_camera"][env_ids] = ray_depth.view(-1, *self.image_shape, 1)

        if "normals" in data_types:
            self._data.output["normals"][env_ids] = ray_normal.view(-1, *self.image_shape, 3)

        if self.cfg.update_mesh_ids:
            self._data.image_mesh_ids[env_ids] = ray_mesh_ids.view(-1, *self.image_shape, 1)