def parse_common()

in tinynn/converter/operators/torch/aten.py [0:0]


    def parse_common(self, graph_converter, input_idx=0, mask_idx=1, other_idx=2, out_idx=0):
        """Lower ``aten::masked_fill(input, mask, value)`` to TFLite arithmetic.

        The result is composed as ``input * (1 - mask) + value * mask`` where the
        mask is first cast to the input's dtype. Sub-expressions whose operands
        all carry a buffer (i.e. are constant) are folded into attribute tensors
        instead of emitting operators.

        Because the fill value is multiplied by the 0/1 mask, an infinite value
        would yield NaN (``inf * 0``) at the unmasked positions. Constant fill
        values containing ``+/-inf`` are therefore clamped to the finite range of
        their dtype and a warning is logged.

        Side effects: float64/int64 tensors found at ``input_idx``/``other_idx``
        of ``self.input_tensors`` and at ``out_idx`` of ``self.output_tensors``
        are replaced in place by float32/int32 versions.

        Args:
            graph_converter: Passed to ``find_or_create_input``; receives the
                generated operators through ``add_operator``.
            input_idx: Index of the input tensor in ``self.input_tensors``.
            mask_idx: Index of the mask tensor in ``self.input_tensors``.
            other_idx: Index of the fill value (tensor, int or float) in
                ``self.input_tensors``.
            out_idx: Index of the result in ``self.output_tensors``.

        Raises:
            AssertionError: If the fill value is not a ``torch.Tensor``, ``int``
                or ``float``.
        """
        inf_warning = (
            'aten::masked_fill(input, mask, value) where value=[+/-]inf is not supported, '
            'trying to convert it to the nearest value'
        )

        # Narrow 64-bit inputs to their 32-bit counterparts.
        for i in (input_idx, other_idx):
            t = self.input_tensors[i]
            if type(t) is torch.Tensor:
                if t.dtype == torch.float64:
                    self.input_tensors[i] = t.to(dtype=torch.float32)
                elif t.dtype == torch.int64:
                    self.input_tensors[i] = t.to(dtype=torch.int32)

        # Same narrowing for the output so its dtype matches the input's.
        if self.output_tensors[out_idx].dtype == torch.float64:
            self.output_tensors[out_idx] = self.output_tensors[out_idx].to(dtype=torch.float32)
        elif self.output_tensors[out_idx].dtype == torch.int64:
            self.output_tensors[out_idx] = self.output_tensors[out_idx].to(dtype=torch.int32)

        mask = self.input_tensors[mask_idx]
        other = self.input_tensors[other_idx]
        out = self.output_tensors[out_idx]

        input_tensor, mask_tensor = [self.find_or_create_input(i, graph_converter) for i in (input_idx, mask_idx)]

        ops = []
        if type(other) is torch.Tensor:
            other_t = self.find_or_create_input(other_idx, graph_converter)
            needs_cast = out.dtype != other.dtype
            if other_t.buffer is None:
                # Fill value has no buffer: its dtype can only be fixed with a Cast op.
                if needs_cast:
                    casted = other.clone().to(dtype=out.dtype)
                    new_other = self.create_transform_tensor(casted)
                    ops.append(
                        tfl.CastOperator(
                            [other_t],
                            [new_other],
                            tfl.torch_tflite_dtype_mappings[other.dtype],
                            tfl.torch_tflite_dtype_mappings[out.dtype],
                        )
                    )
                    other_t = new_other
                # TODO: +/- inf check for variable tensors
            else:
                # Constant fill value: cast and clamp offline. The inf check must run
                # even when no cast is required, otherwise ``inf * mask`` produces NaN.
                casted = other.clone().to(dtype=out.dtype) if needs_cast else other
                if hasattr(torch.functional, 'atleast_1d'):
                    casted = torch.functional.atleast_1d(casted)
                elif len(casted.shape) == 0:
                    casted = casted.reshape(1)
                # ``torch.finfo`` only accepts floating dtypes, and only those can hold inf.
                has_inf = casted.dtype.is_floating_point and bool(torch.isinf(casted).any())
                if has_inf:
                    log.warning(inf_warning)
                    type_info = torch.finfo(casted.dtype)
                    casted = torch.clamp(casted, type_info.min, type_info.max)
                if needs_cast or has_inf:
                    other_t = self.create_attr_tensor(casted, name=self.input_names[other_idx])
        elif type(other) in (int, float):
            # Scalar fill value: materialise it as a one-element constant in the input's dtype.
            other_a = np.array([other], dtype=self.input_tensors[input_idx].detach().numpy().dtype)
            if np.isinf(other_a).any():
                log.warning(inf_warning)
                type_info = np.finfo(other_a.dtype)
                other_a = np.clip(other_a, type_info.min, type_info.max)
            other_t = self.create_attr_tensor(other_a)
        else:
            # Explicit raise (not ``assert False``) so the check survives ``python -O``.
            raise AssertionError("value should have type float, tensor in aten::masked_fill(input, mask, value)")

        # Mask converted to the input's dtype so it can take part in arithmetic.
        if mask_tensor.buffer is None:
            input_mask = self.create_transform_tensor(mask_tensor.tensor.astype(input_tensor.dtype))
            ops.append(
                tfl.CastOperator(
                    [mask_tensor],
                    [input_mask],
                    tfl.torch_tflite_dtype_mappings[mask.dtype],
                    tfl.torch_tflite_dtype_mappings[out.dtype],
                )
            )
        else:
            input_mask = self.create_attr_tensor(mask_tensor.tensor.astype(input_tensor.dtype))

        # masked = value * mask
        if mask_tensor.buffer is None or other_t.buffer is None:
            masked = self.create_transform_tensor(other_t.tensor * mask_tensor.tensor)
            ops.append(tfl.MulOperator([other_t, input_mask], [masked]))
        else:
            masked = self.create_attr_tensor(other_t.tensor * mask_tensor.tensor)

        # rev_mask = 1 - mask
        one_tensor = self.create_attr_tensor(np.array([1], dtype=input_tensor.dtype))
        if mask_tensor.buffer is None:
            rev_mask = self.create_transform_tensor(one_tensor.tensor - mask_tensor.tensor)
            ops.append(tfl.SubOperator([one_tensor, input_mask], [rev_mask]))
        else:
            rev_mask = self.create_attr_tensor(one_tensor.tensor - mask_tensor.tensor)

        # non_masked = input * (1 - mask)
        non_masked = self.create_transform_tensor(input_tensor.tensor * rev_mask.tensor)
        ops.append(tfl.MulOperator([input_tensor, rev_mask], [non_masked]))

        # output = non_masked + masked
        outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
        ops.append(tfl.AddOperator([non_masked, masked], outputs))

        for op in ops:
            graph_converter.add_operator(op)