in seamseg/utils/parallel/packed_sequence.py [0:0]
def __init__(self, *args):
    """Build the sequence from tensors given as varargs or as one list/tuple.

    Entries may be ``None``; ``None`` entries are skipped by every
    consistency check below but are kept in the stored sequence.

    Args:
        *args: Either the entries themselves (each a ``torch.Tensor`` or
            ``None``), or a single list or tuple holding those entries.

    Raises:
        TypeError: If an entry is neither ``None`` nor a ``torch.Tensor``,
            or if ``_all_same`` rejects the dtypes or the devices of the
            non-``None`` tensors.
    """
    # A lone list/tuple argument is treated as the sequence itself.
    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        entries = args[0]
    else:
        entries = args

    # Check that every entry is a tensor (or None) before touching
    # tensor-only attributes such as dtype/device.
    for entry in entries:
        if entry is not None and not isinstance(entry, torch.Tensor):
            raise TypeError("All args must be tensors")

    # Filter the None placeholders once; all remaining checks use this list.
    present = [entry for entry in entries if entry is not None]

    if not _all_same([tensor.dtype for tensor in present]):
        raise TypeError("All tensors must have the same type")
    if not _all_same([tensor.device for tensor in present]):
        raise TypeError("All tensors must reside on the same device")

    # Always store a fresh list: the attribute has one container type no
    # matter how the constructor was called, and later mutation of the
    # caller's list cannot silently change this sequence behind the cached
    # flags computed below.
    self._tensors = list(entries)

    # Cached properties of the sequence.
    # _compatible: result of _all_same over the shapes with the first
    # dimension dropped (shape[1:]) of the non-None tensors.
    self._compatible = _all_same([tensor.shape[1:] for tensor in present])
    # _all_none: True when there is no actual tensor (also for empty input).
    self._all_none = not present