in rlmeta/core/replay_buffer.py [0:0]
def _update_priority(self, index: Union[int, Tensor],
                     priority: Union[float, Tensor]) -> None:
    """Write new priorities for one or more entries into the priority trees.

    ``self.eps`` is added to ``priority`` first. The running maximum
    ``self._max_priority`` is updated from these eps-adjusted values, before
    the ``self.alpha`` exponent is applied. The exponentiated priorities are
    then cast to ``self.priority_type`` (for NumPy arrays and torch tensors).
    Finally they are stored at ``index`` in both ``self._sum_tree`` and
    ``self._min_tree``.

    Args:
        index: Position(s) of the entries whose priority is being updated.
        priority: New raw priority. Either a scalar or an array/tensor of
            values (presumably one per entry in ``index`` -- confirm against
            callers). The caller's array/tensor is NOT modified.
    """
    # Out-of-place add. ``priority += self.eps`` would mutate the caller's
    # ndarray/tensor, and it raises for integer-dtype inputs.
    priority = priority + self.eps
    if isinstance(priority, (int, float)):
        # Plain Python scalar (np.float64 is a float subclass as well).
        self._max_priority = max(self._max_priority, priority)
    else:
        # Array-like: reduce to a Python number before comparing.
        self._max_priority = max(self._max_priority, priority.max().item())
    # The max above is tracked on the pre-exponent values; only the stored
    # tree values are raised to ``alpha``.
    priority = priority**self.alpha
    if isinstance(priority, np.ndarray):
        priority = priority.astype(self.priority_type)
    elif isinstance(priority, torch.Tensor):
        priority = priority.to(
            data_utils.numpy_dtype_to_torch(self.priority_type))
    self._sum_tree[index] = priority
    self._min_tree[index] = priority