in differentiable_robot_model/robot_model.py [0:0]
def tensor_check(function):
    """
    A decorator that validates and normalizes the tensor arguments of a module method.

    Intended for methods ``function(self, *args, **kwargs)`` where ``self`` has a
    ``_device`` attribute. Every positional and keyword argument whose type is
    exactly ``torch.Tensor`` is checked. Tensor subclasses such as
    ``torch.nn.Parameter`` and non-tensor values are passed through untouched.
    Each checked tensor must:

    * be on the same device type as ``self._device``. Only ``device.type`` is
      compared, so e.g. ``cuda:0`` and ``cuda:1`` are treated as equal;
    * have ``ndim`` of 1 or 2;
    * have the same batch shape (all dimensions but the last) as every other
      checked tensor argument.

    If the checked tensors are unbatched (1-dim), each is given a leading batch
    dimension of size 1 before ``function`` is called. That leading dimension is
    then indexed away from every tensor the function returns, whether returned
    directly or as an element of a plain ``tuple``. Returned 0-dim tensors are
    returned unchanged.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    import functools  # local import keeps the decorator self-contained

    @dataclass
    class BatchInfo:
        # Batch shape (all dims but the last) of the first tensor argument seen.
        shape: torch.Size = torch.Size([])
        # Whether a tensor argument has been seen yet in this call.
        init: bool = False

    def preprocess(arg, obj, batch_info):
        """Validate a tensor argument and promote it to 2-dim if it is unbatched."""
        if type(arg) is not torch.Tensor:
            return arg
        # Check device (device type only, not index)
        assert arg.device.type == obj._device.type, (
            f"Input argument on device '{arg.device}' differs from module "
            f"device '{obj._device}'."
        )
        # Check dimensions & convert to 2-dim tensors
        assert arg.ndim in [1, 2], (
            f"Input tensors must have ndim of 1 or 2, got {arg.ndim}."
        )
        if batch_info.init:
            assert batch_info.shape == arg.shape[:-1], (
                "Batch size mismatch between input tensors: expected batch shape "
                f"{tuple(batch_info.shape)}, got {tuple(arg.shape[:-1])}."
            )
        else:
            # The first tensor argument defines the batch shape for the call.
            batch_info.init = True
            batch_info.shape = arg.shape[:-1]
        if len(batch_info.shape) == 0:
            # Unbatched input: add a leading batch dim so `function` sees 2-dim.
            return arg.unsqueeze(0)
        return arg

    def postprocess(arg, batch_info):
        """Remove the leading batch dim added by ``preprocess`` from an output."""
        if (
            type(arg) is torch.Tensor
            and batch_info.init
            and len(batch_info.shape) == 0
            # A 0-dim output (e.g. a reduction) has no batch dim to strip.
            and arg.ndim > 0
        ):
            return arg[0, ...]
        return arg

    @functools.wraps(function)
    def wrapper(self, *args, **kwargs):
        batch_info = BatchInfo()
        # Parse input
        processed_args = [preprocess(arg, self, batch_info) for arg in args]
        processed_kwargs = {
            key: preprocess(value, self, batch_info) for key, value in kwargs.items()
        }
        # Perform function
        ret = function(self, *processed_args, **processed_kwargs)
        # Parse output
        if type(ret) is torch.Tensor:
            return postprocess(ret, batch_info)
        if type(ret) is tuple:
            return tuple(postprocess(r, batch_info) for r in ret)
        return ret

    return wrapper