in dualpipe/utils.py [0:0]
def cat_tensor(x, dim):
    """Concatenate a sequence of tensors along ``dim``, with shortcuts.

    Args:
        x: Usually a ``list`` or ``tuple`` of tensors (or of ``None``
            placeholders). Any other value is handed to ``torch.cat``
            unchanged.
        dim: Dimension passed to ``torch.cat`` as ``dim``.

    Returns:
        * ``x[0]`` itself (no ``torch.cat`` call) when ``x`` is a list/tuple
          with exactly one element;
        * ``None`` when ``x`` is a list/tuple whose entries are all ``None``;
        * otherwise the result of ``torch.cat(x, dim=dim)``.

    Raises:
        AssertionError: If ``x[0]`` is ``None`` but some later entry is not.
        IndexError: If ``x`` is an empty list/tuple (``x[0]`` is evaluated).
    """
    if isinstance(x, (tuple, list)):
        if len(x) == 1:
            # Single entry: return it as-is rather than calling torch.cat.
            return x[0]
        if x[0] is None:
            # Raised explicitly instead of via `assert`, which is stripped
            # under `python -O` and would let a mixed None/tensor sequence
            # silently collapse to None. The exception type is unchanged.
            if any(y is not None for y in x):
                raise AssertionError(
                    "cat_tensor: sequence mixes None and non-None entries"
                )
            return None
    return torch.cat(x, dim=dim)