in petastorm/pytorch.py [0:0]
def decimal_friendly_collate(batch):
    """Collate a batch like ``default_collate`` while passing ``decimal.Decimal`` values through.

    PyTorch's ``default_collate`` implementation does not support collating ``decimal.Decimal`` types (used in
    petastorm datasets to represent timestamps). This function walks the batch structure recursively: mappings
    are collated per key, sequences are collated per position, ``decimal.Decimal`` and string batches are
    returned as they are, and everything else is delegated to ``default_collate``.

    The type of the batch is decided by looking at its first element only, so ``batch`` must not be empty.

    :param batch: A non-empty list of samples to collate (typically a list of dictionaries)
    :return: The collated structure: a dictionary for mapping samples, a list for sequence samples, the
      unchanged ``batch`` for ``decimal.Decimal``/string samples, otherwise the result of ``default_collate``
    """
    # The ``collections.Mapping``/``collections.Sequence`` aliases were removed in Python 3.10; the ABCs
    # live in ``collections.abc``. Fall back to the old location for Python 2 interpreters.
    try:
        from collections.abc import Mapping, Sequence
    except ImportError:
        from collections import Mapping, Sequence

    first = batch[0]
    if isinstance(first, decimal.Decimal):
        # Decimals can not be turned into tensors: hand them back uncollated.
        return batch
    elif isinstance(first, Mapping):
        # Collate field by field; the keys of the first sample define the fields of the result.
        return {key: decimal_friendly_collate([d[key] for d in batch]) for key in first}
    elif isinstance(first, _string_classes):
        # Must be tested before ``Sequence``: strings are sequences too and would otherwise be
        # transposed character by character.
        return batch
    elif isinstance(first, Sequence):
        # Regroup the samples by position, then collate each position independently.
        transposed = zip(*batch)
        return [decimal_friendly_collate(samples) for samples in transposed]
    else:
        return default_collate(batch)