in tools/utils.py [0:0]
def cat_dataclass(batch, tensor_collator):
    """Collate a batch of dataclass instances into one instance, field by field.

    The first element of ``batch`` decides how each field is handled:

    * ``None``       -> the collated field is ``None``.
    * tensor         -> ``tensor_collator`` is called with the list of that
                        field's values taken from every element of ``batch``.
    * dataclass      -> collated recursively with this function.
    * mapping        -> collated key by key into a plain ``dict``; a key whose
                        value is ``None`` in the first element stays ``None``,
                        every other key goes through ``tensor_collator``.

    Only the first element is inspected when choosing a branch, so all
    elements are assumed to share the same layout (same ``None`` fields, same
    mapping keys) -- NOTE(review): confirm against callers.

    Args:
        batch: Indexable, iterable collection of instances of the same
            dataclass. Must contain at least one element.
        tensor_collator: Callable that takes a list of per-element values and
            returns the collated value (e.g. ``torch.cat`` or ``torch.stack``).

    Returns:
        A new instance of ``type(batch[0])`` built from the collated fields.
        Every field is passed to the constructor as a keyword argument.

    Raises:
        IndexError: If ``batch`` is empty.
        ValueError: If a field of the first element is not ``None``, a tensor,
            a dataclass, or a mapping.
    """
    try:
        first = batch[0]
    except IndexError:
        # Same exception type as the bare lookup would raise, but with a
        # message that says what actually went wrong.
        raise IndexError("cat_dataclass() requires a non-empty batch") from None

    collated = {}
    for field in dataclasses.fields(first):
        name = field.name
        first_value = getattr(first, name)
        if first_value is None:
            collated[name] = None
            continue

        # Gather the field from every element once; all branches below use it.
        values = [getattr(elem, name) for elem in batch]
        if torch.is_tensor(first_value):
            collated[name] = tensor_collator(values)
        elif dataclasses.is_dataclass(first_value):
            collated[name] = cat_dataclass(values, tensor_collator)
        elif isinstance(first_value, collections.abc.Mapping):
            collated[name] = {
                key: (
                    tensor_collator([value[key] for value in values])
                    if first_value[key] is not None
                    else None
                )
                for key in first_value
            }
        else:
            # Keep the original message as a prefix and add the offending
            # field and type so the failure can be located.
            raise ValueError(
                "Unsupported field type for concatenation: "
                f"field {name!r} has type {type(first_value).__name__}"
            )
    return type(first)(**collated)