in petastorm/pytorch.py [0:0]
def _sanitize_pytorch_types(row_as_dict):
    """Promotes value types in a dictionary to types supported by PyTorch.

    The dictionary is modified in-place. Values are handled as follows:

    - ``int8`` arrays are promoted to ``int16``, but only when the module flag
      ``_TORCH_BEFORE_1_1`` is truthy;
    - ``uint16`` arrays are promoted to ``int32``;
    - ``uint32`` arrays are promoted to ``int64``;
    - ``bool`` arrays are converted to ``uint8`` arrays, and ``numpy.bool_``
      scalars to ``numpy.uint8`` scalars;
    - byte-string, unicode-string and object arrays are rejected;
    - ``None`` values are rejected;
    - every other value is left unchanged.

    :param dict[str,obj] row_as_dict: a dictionary of key-value pairs. The value types are promoted to
        PyTorch compatible ones.
    :return: None
    :raises TypeError: if a value is a string/object array or is ``None``.
    """
    # Only existing keys are reassigned below, so the dict's size never changes and
    # iterating over ``items()`` while updating values is safe.
    for name, value in row_as_dict.items():
        # PyTorch supported types are: double, float, float16, int64, int32, and uint8
        if isinstance(value, np.ndarray):
            if value.dtype == np.int8 and _TORCH_BEFORE_1_1:
                row_as_dict[name] = value.astype(np.int16)
            elif value.dtype == np.uint16:
                row_as_dict[name] = value.astype(np.int32)
            elif value.dtype == np.uint32:
                row_as_dict[name] = value.astype(np.int64)
            elif value.dtype == np.bool_:
                row_as_dict[name] = value.astype(np.uint8)
            # Classify by dtype kind ('S' bytes, 'U' unicode, 'O' object) rather than by
            # pattern-matching dtype.str, which also matched datetime units such as '[as]'.
            elif value.dtype.kind in ('S', 'U', 'O'):
                raise TypeError('Pytorch does not support arrays of string or object classes. '
                                'Found in field {}.'.format(name))
        elif isinstance(value, np.bool_):
            row_as_dict[name] = np.uint8(value)
        elif value is None:
            raise TypeError('Pytorch does not support nullable fields. Found None in {}'.format(name))