Source code for isaaclab.utils.buffers.circular_buffer
# Copyright (c) 2022-2025, The Isaac Lab Project Developers.# All rights reserved.## SPDX-License-Identifier: BSD-3-Clauseimporttorchfromcollections.abcimportSequence
[docs]classCircularBuffer:"""Circular buffer for storing a history of batched tensor data. This class implements a circular buffer for storing a history of batched tensor data. The buffer is initialized with a maximum length and a batch size. The data is stored in a circular fashion, and the data can be retrieved in a LIFO (Last-In-First-Out) fashion. The buffer is designed to be used in multi-environment settings, where each environment has its own data. The shape of the appended data is expected to be (batch_size, ...), where the first dimension is the batch dimension. Correspondingly, the shape of the ring buffer is (max_len, batch_size, ...). """
[docs]def__init__(self,max_len:int,batch_size:int,device:str):"""Initialize the circular buffer. Args: max_len: The maximum length of the circular buffer. The minimum allowed value is 1. batch_size: The batch dimension of the data. device: The device used for processing. Raises: ValueError: If the buffer size is less than one. """ifmax_len<1:raiseValueError(f"The buffer size should be greater than zero. However, it is set to {max_len}!")# set the parametersself._batch_size=batch_sizeself._device=deviceself._ALL_INDICES=torch.arange(batch_size,device=device)# max length tensor for comparisonsself._max_len=torch.full((batch_size,),max_len,dtype=torch.int,device=device)# number of data pushes passed since the last call to :meth:`reset`self._num_pushes=torch.zeros(batch_size,dtype=torch.long,device=device)# the pointer to the current head of the circular buffer (-1 means not initialized)self._pointer:int=-1# the actual buffer for data storage# note: this is initialized on the first call to :meth:`append`self._buffer:torch.Tensor=None# type: ignore
""" Properties. """@propertydefbatch_size(self)->int:"""The batch size of the ring buffer."""returnself._batch_size@propertydefdevice(self)->str:"""The device used for processing."""returnself._device@propertydefmax_length(self)->int:"""The maximum length of the ring buffer."""returnint(self._max_len[0].item())@propertydefcurrent_length(self)->torch.Tensor:"""The current length of the buffer. Shape is (batch_size,). Since the buffer is circular, the current length is the minimum of the number of pushes and the maximum length. """returntorch.minimum(self._num_pushes,self._max_len)@propertydefbuffer(self)->torch.Tensor:"""Complete circular buffer with most recent entry at the end and oldest entry at the beginning. Returns: Complete circular buffer with most recent entry at the end and oldest entry at the beginning of dimension 1. The shape is [batch_size, max_length, data.shape[1:]]. """buf=self._buffer.clone()buf=torch.roll(buf,shifts=self.max_length-self._pointer-1,dims=0)returntorch.transpose(buf,dim0=0,dim1=1)""" Operations. """
[docs]defreset(self,batch_ids:Sequence[int]|None=None):"""Reset the circular buffer at the specified batch indices. Args: batch_ids: Elements to reset in the batch dimension. Default is None, which resets all the batch indices. """# resolve all indicesifbatch_idsisNone:batch_ids=slice(None)# reset the number of pushes for the specified batch indicesself._num_pushes[batch_ids]=0ifself._bufferisnotNone:# set buffer at batch_id reset indices to 0.0 so that the buffer() getter returns the cleared circular buffer after reset.self._buffer[:,batch_ids,:]=0.0
[docs]defappend(self,data:torch.Tensor):"""Append the data to the circular buffer. Args: data: The data to append to the circular buffer. The first dimension should be the batch dimension. Shape is (batch_size, ...). Raises: ValueError: If the input data has a different batch size than the buffer. """# check the batch sizeifdata.shape[0]!=self.batch_size:raiseValueError(f"The input data has {data.shape[0]} environments while expecting {self.batch_size}")# at the first call, initialize the buffer sizeifself._bufferisNone:self._pointer=-1self._buffer=torch.empty((self.max_length,*data.shape),dtype=data.dtype,device=self._device)# move the head to the next slotself._pointer=(self._pointer+1)%self.max_length# add the new data to the last layerself._buffer[self._pointer]=data.to(self._device)# Check for batches with zero pushes and initialize all values in batch to first appendif0inself._num_pushes.tolist():fill_ids=[ifori,xinenumerate(self._num_pushes.tolist())ifx==0]self._num_pushes.tolist().index(0)if0inself._num_pushes.tolist()elseNoneself._buffer[:,fill_ids,:]=data.to(self._device)[fill_ids]# increment number of number of pushes for all batchesself._num_pushes+=1
def__getitem__(self,key:torch.Tensor)->torch.Tensor:"""Retrieve the data from the circular buffer in last-in-first-out (LIFO) fashion. If the requested index is larger than the number of pushes since the last call to :meth:`reset`, the oldest stored data is returned. Args: key: The index to retrieve from the circular buffer. The index should be less than the number of pushes since the last call to :meth:`reset`. Shape is (batch_size,). Returns: The data from the circular buffer. Shape is (batch_size, ...). Raises: ValueError: If the input key has a different batch size than the buffer. RuntimeError: If the buffer is empty. """# check the batch sizeiflen(key)!=self.batch_size:raiseValueError(f"The argument 'key' has length {key.shape[0]}, while expecting {self.batch_size}")# check if the buffer is emptyiftorch.any(self._num_pushes==0)orself._bufferisNone:raiseRuntimeError("Attempting to retrieve data on an empty circular buffer. Please append data first.")# admissible lagvalid_keys=torch.minimum(key,self._num_pushes-1)# the index in the circular buffer (pointer points to the last+1 index)index_in_buffer=torch.remainder(self._pointer-valid_keys,self.max_length)# return outputreturnself._buffer[index_in_buffer,self._ALL_INDICES]