def cat_dataclass()

in tools/utils.py [0:0]


def cat_dataclass(batch, tensor_collator):
    """Collate a sequence of dataclass instances into a single instance.

    How each field is merged is decided by the value on the *first* element:

    * ``None`` -> ``None`` (the other elements' values are not inspected).
    * Tensor -> the list of that field's values across the batch is passed to
      ``tensor_collator``.
    * Nested dataclass instance -> collated recursively with the same
      ``tensor_collator``.
    * Mapping -> collated key-by-key over the first element's keys; a key whose
      first-element value is ``None`` maps to ``None``, every other key's list
      of values is passed to ``tensor_collator``.

    Fields declared with ``init=False`` are skipped because the generated
    ``__init__`` does not accept them; the new instance initialises them
    itself (field default or ``__post_init__``).

    NOTE(review): every element is presumed to share the first element's
    dataclass type and (for mappings) key set -- this is not checked.

    Args:
        batch: Non-empty sequence of dataclass instances.
        tensor_collator: Callable that receives a list of per-element values
            and returns their combined value (for example ``torch.cat``).

    Returns:
        A new instance of ``type(batch[0])`` built from the collated fields.

    Raises:
        ValueError: If ``batch`` is empty, or if a field holds a value that is
            not ``None``, a tensor, a dataclass instance, or a mapping.
    """
    if not batch:
        raise ValueError("Cannot concatenate an empty batch")

    elem = batch[0]
    collated = {}

    for f in dataclasses.fields(elem):
        if not f.init:
            # Passing an init=False field as a keyword to the constructor raises
            # TypeError; let the constructor set it (default / __post_init__).
            continue

        elem_f = getattr(elem, f.name)
        if elem_f is None:
            collated[f.name] = None
            continue

        values = [getattr(e, f.name) for e in batch]
        if torch.is_tensor(elem_f):
            collated[f.name] = tensor_collator(values)
        elif dataclasses.is_dataclass(elem_f) and not isinstance(elem_f, type):
            # is_dataclass() is also True for dataclass *classes*; only recurse
            # into instances.
            collated[f.name] = cat_dataclass(values, tensor_collator)
        elif isinstance(elem_f, collections.abc.Mapping):
            collated[f.name] = {
                k: tensor_collator([v[k] for v in values])
                if elem_f[k] is not None
                else None
                for k in elem_f
            }
        else:
            raise ValueError(
                "Unsupported field type for concatenation: "
                f"field {f.name!r} has type {type(elem_f).__name__}"
            )

    return type(elem)(**collated)