in tinynn/converter/operators/torch/aten.py
def parse_common(self, graph_converter, input_idx=0, mask_idx=1, other_idx=2, out_idx=0):
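    # Lowers aten::masked_fill(input, mask, value) into elementwise TFLite ops
    # using the identity: output = input * (1 - mask) + value * mask.
    # 64-bit tensors are narrowed to 32-bit first, since the generated TFLite
    # graph works with 32-bit types.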
    for i in (input_idx, other_idx):
        t = self.input_tensors[i]
        if type(t) is torch.Tensor:
            if t.dtype == torch.float64:
                self.input_tensors[i] = t.to(dtype=torch.float32)
            elif t.dtype == torch.int64:
                self.input_tensors[i] = t.to(dtype=torch.int32)
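
    # Apply the same 64-bit -> 32-bit narrowing to the output tensor.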
    if self.output_tensors[out_idx].dtype == torch.float64:
        self.output_tensors[out_idx] = self.output_tensors[out_idx].to(dtype=torch.float32)
    elif self.output_tensors[out_idx].dtype == torch.int64:
        self.output_tensors[out_idx] = self.output_tensors[out_idx].to(dtype=torch.int32)

    mask = self.input_tensors[mask_idx]
    other = self.input_tensors[other_idx]
    out = self.output_tensors[out_idx]

    input_tensor, mask_tensor = [self.find_or_create_input(i, graph_converter) for i in (input_idx, mask_idx)]

    ops = []
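
    # Build the fill-value tensor `other_t` in the output dtype. It may come in
    # as a graph tensor (cast at runtime) or as a constant (cast at conversion time).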
    if type(other) is torch.Tensor:
        other_t = self.find_or_create_input(other_idx, graph_converter)
        if out.dtype != other.dtype:
            casted = other.clone().to(dtype=out.dtype)
            if other_t.buffer is None:
                new_other = self.create_transform_tensor(casted)
                ops.append(
                    tfl.CastOperator(
                        [other_t],
                        [new_other],
                        tfl.torch_tflite_dtype_mappings[other.dtype],
                        tfl.torch_tflite_dtype_mappings[out.dtype],
                    )
                )
                other_t = new_other
                # TODO: +/- inf check for variable tensors
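            # Constant tensor: cast eagerly at conversion time; infinite fill
            # values are not supported, so they are clamped to the dtype's
            # finite range (with a warning).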
            else:
                if hasattr(torch.functional, 'atleast_1d'):
                    casted = torch.functional.atleast_1d(casted)
                elif len(casted.shape) == 0:
                    casted = casted.reshape(1)
                if torch.isinf(casted).any():
                    log.warning(
                        'aten::masked_fill(input, mask, value) where value=[+/-]inf is not supported, '
                        'trying to convert it to the nearest value'
                    )
                    type_info = torch.finfo(casted.dtype)
                    clamped = torch.clamp(casted, type_info.min, type_info.max)
                    other_t = self.create_attr_tensor(clamped, name=self.input_names[other_idx])
                else:
                    other_t = self.create_attr_tensor(casted, name=self.input_names[other_idx])
    elif type(other) in (int, float):
        other_a = np.array([other], dtype=self.input_tensors[input_idx].detach().numpy().dtype)
        if np.isinf(other_a).any():
            log.warning(
                'aten::masked_fill(input, mask, value) where value=[+/-]inf is not supported, '
                'trying to convert it to the nearest value'
            )
            type_info = np.finfo(other_a.dtype)
            other_a = np.clip(other_a, type_info.min, type_info.max)
        other_t = self.create_attr_tensor(other_a)
    else:
        assert False, "value should have type int, float or tensor in aten::masked_fill(input, mask, value)"
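
    # Cast the boolean mask to the compute dtype so it can be used in arithmetic.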
    if mask_tensor.buffer is None:
        input_mask = self.create_transform_tensor(mask_tensor.tensor.astype(input_tensor.dtype))
        ops.append(
            tfl.CastOperator(
                [mask_tensor],
                [input_mask],
                tfl.torch_tflite_dtype_mappings[mask.dtype],
                tfl.torch_tflite_dtype_mappings[out.dtype],
            )
        )
    else:
        input_mask = self.create_attr_tensor(mask_tensor.tensor.astype(input_tensor.dtype))
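
    # masked = value * mask: holds the fill value at masked positions, zero elsewhere.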
    if mask_tensor.buffer is None or other_t.buffer is None:
        masked = self.create_transform_tensor(other_t.tensor * mask_tensor.tensor)
        ops.append(tfl.MulOperator([other_t, input_mask], [masked]))
    else:
        masked = self.create_attr_tensor(other_t.tensor * mask_tensor.tensor)
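
    # rev_mask = 1 - mask: selects the positions that keep their original values.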
    one_tensor = self.create_attr_tensor(np.array([1], dtype=input_tensor.dtype))
    if mask_tensor.buffer is None:
        rev_mask = self.create_transform_tensor(one_tensor.tensor - mask_tensor.tensor)
        ops.append(tfl.SubOperator([one_tensor, input_mask], [rev_mask]))
    else:
        rev_mask = self.create_attr_tensor(one_tensor.tensor - mask_tensor.tensor)
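
    # Combine: output = input * rev_mask + masked.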
    non_masked = self.create_transform_tensor(input_tensor.tensor * rev_mask.tensor)
    ops.append(tfl.MulOperator([input_tensor, rev_mask], [non_masked]))

    outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
    ops.append(tfl.AddOperator([non_masked, masked], outputs))

    for op in ops:
        graph_converter.add_operator(op)
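
# A minimal numpy sketch of the decomposition used above (illustrative only, not
# part of the converter; the names below are made up):
#
#     import numpy as np
#     x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
#     mask = np.array([0.0, 1.0, 0.0], dtype=np.float32)  # bool mask cast to float
#     value = np.float32(9.0)
#     out = x * (1.0 - mask) + value * mask               # -> [1.0, 9.0, 3.0]
#     assert np.array_equal(out, np.array([1.0, 9.0, 3.0], dtype=np.float32))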