in mmf/common/sample.py [0:0]
def __init__(self, samples=None):
    """Build a batched sample list by collating a list of per-sample mappings.

    Before collating, ``samples`` is offered to ``_check_and_load_dict`` and
    then ``_check_and_load_tuple``; if either returns a truthy value,
    construction stops there.

    Otherwise, for every key of the first sample, values are collated across
    all samples:

    * If the first sample's value is a ``torch.Tensor``, a tensor of size
      ``(len(samples), *first_value.size())`` is allocated with
      ``new_empty`` and row ``idx`` is filled with the result of
      ``_get_data_copy(samples[idx][field])``. The first such field is
      registered via ``_set_tensor_field`` if ``_get_tensor_field()`` is
      still ``None``.
    * Otherwise the (copied) values are collected into a list with one entry
      per sample.

    A field whose first-sample value is a ``collections.abc.Mapping`` is
    wrapped in a nested ``SampleList`` after collation.

    Args:
        samples: Per-sample data as described above. ``None`` or an empty
            sequence leaves the instance with no fields.

    Raises:
        AssertionError: If, for a tensor field, some sample holds a non-0-d
            tensor whose size differs from the first sample's value.
    """
    # Initialise the base class with no data; fields are assigned below.
    super().__init__()
    if samples is None:
        samples = []

    if len(samples) == 0:
        return

    if self._check_and_load_dict(samples):
        return
    # If passed sample list was in form of key, value pairs of tuples
    # return after loading these
    if self._check_and_load_tuple(samples):
        return

    num_samples = len(samples)
    first_sample = samples[0]

    for field in first_sample.keys():
        first_value = first_sample[field]
        is_tensor_field = isinstance(first_value, torch.Tensor)

        if is_tensor_field:
            # Preallocate one stacked tensor; its rows are filled below.
            size = (num_samples, *first_value.size())
            self[field] = first_value.new_empty(size)
            if self._get_tensor_field() is None:
                self._set_tensor_field(field)
        else:
            self[field] = [None] * num_samples

        for idx, sample in enumerate(samples):
            value = sample[field]
            # Only stacked tensor fields need matching sizes: writing a
            # differently shaped tensor into a preallocated row would either
            # raise an opaque torch error or silently broadcast. List fields
            # may hold values of any shape. 0-d tensors stay exempt.
            if (
                is_tensor_field
                and isinstance(value, torch.Tensor)
                and value.dim() != 0
                and value.size() != first_value.size()
            ):
                raise AssertionError(
                    "Fields for all samples must be equally sized. "
                    "{} is of different sizes".format(field)
                )

            self[field][idx] = self._get_data_copy(value)

        if isinstance(first_value, collections.abc.Mapping):
            self[field] = SampleList(self[field])