def parse()

in tinynn/converter/operators/torch/aten.py


    def parse(self, node, attrs, args, graph_converter):
        super().parse(node, attrs, args, graph_converter)

        # `accumulate=True` asks for values to be accumulated at repeated
        # indices, which this lowering does not support
        accumulate = self.input_tensors[3]
        assert not accumulate, "aten::index_put_ with accumulate=True is not supported"

        # torch.Tensor.index_put_ requires index tensors of type `torch.int64`,
        # so cast them for the reference run and restore the dtype afterwards
        orig_type = self.input_tensors[1][0].dtype
        self.input_tensors[1] = tuple(x.to(dtype=torch.int64) for x in self.input_tensors[1])
        self.run(node)

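        # TFLite counterparts of the tensor being written and of the op output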
        input_tensor = self.find_or_create_input(0, graph_converter)
        output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]

        # restore the original index dtype so the graph tensors created below
        # match what is actually present in the graph
        self.input_tensors[1] = tuple(x.to(dtype=orig_type) for x in self.input_tensors[1])

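        # resolve the indices argument: expand a dynamic tensor list into its
        # element names; otherwise treat it as constants or a single graph input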
        if graph_converter.has_nested_names(self.input_names[1]):
            input_names = graph_converter.get_list_expanded_names(self.input_names[1])
            indices_tensors = self.to_tfl_tensors(
                input_names, self.input_tensors[1], graph_converter=graph_converter, non_existent_as_buffer=True
            )
        else:
            if isinstance(self.input_tensors[1], (tuple, list)):
                indices_tensors = [self.create_attr_tensor(x) for x in self.input_tensors[1]]
            else:
                indices_tensors = [self.find_or_create_input(1, graph_converter)]

        dim = input_tensor.tensor.ndim

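        # broadcast the index tensors to a common length: position k of index
        # tensor j is read at `k % len_j`, so length-1 indices simply repeat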
        indices_shape = [x.tensor.size for x in indices_tensors]
        max_len = max(indices_shape)
        indices_shape_tensor = torch.tensor(indices_shape)
        left_indices = (torch.arange(max_len).view(-1, 1).expand(-1, len(indices_shape)) % indices_shape_tensor).int()

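        # with fewer index tensors than input dimensions, the trailing
        # dimensions are indexed in full: enumerate their coordinates and take
        # the cartesian product with the broadcast index positions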
        if len(indices_tensors) < dim:
            pad_shape = list(input_tensor.shape[len(indices_tensors) :])
            pad_indices = torch.ones(pad_shape).nonzero().int()
            left_len = len(indices_shape)
            right_len = len(pad_shape)
            left_size = left_indices.size(0)
            right_size = pad_indices.size(0)
            left_reshaped = left_indices.view(-1, 1, left_len).expand(-1, right_size, left_len).reshape(-1, left_len)
            right_reshaped = pad_indices.view(1, -1, right_len).expand(left_size, -1, right_len).reshape(-1, right_len)
            all_indices = torch.cat([left_reshaped, right_reshaped], 1).unbind(1)
        else:
            all_indices = left_indices.unbind(1)

        # materialize one index column per input dimension
        new_indices = []
        for i in range(dim):
            if i < len(indices_tensors):
                idx_tensor = indices_tensors[i]
                actual_idx = np.take(idx_tensor.tensor, all_indices[i].numpy())
            else:
                actual_idx = all_indices[i].numpy()
            # guard on `i` first so that a stale `idx_tensor` from an earlier
            # iteration is never consulted for the padded dimensions
            if i < len(indices_tensors) and idx_tensor.buffer is None:
                # dynamic index tensor: gather the broadcast positions at runtime
                actual_idx_t = self.create_transform_tensor(actual_idx)
                fake_idx_t = self.create_attr_tensor(all_indices[i].numpy())
                graph_converter.add_operator(tfl.GatherOperator([idx_tensor, fake_idx_t], [actual_idx_t], axis=0))

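                # ND indices are kept in int32 here, so cast wider dtypes down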
                if str(actual_idx_t.dtype) != 'int32':
                    index_casted = self.create_transform_tensor(actual_idx_t.tensor.astype('int32'))
                    graph_converter.add_operator(
                        tfl.CastOperator(
                            [actual_idx_t],
                            [index_casted],
                            tfl.numpy_tflite_dtype_mappings[str(actual_idx_t.dtype)],
                            tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)],
                        )
                    )
                    actual_idx_t = index_casted
                new_indices.append(actual_idx_t)
            else:
                new_indices.append(self.create_attr_tensor(actual_idx.astype(np.int32)))

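        # assemble the (N, dim) ND index matrix: as a constant when every
        # column is static, otherwise packed at runtime along axis 1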
        index_arr = np.stack([x.tensor for x in new_indices], 1)
        if all(x.buffer is not None for x in new_indices):
            index_tensor = self.create_attr_tensor(index_arr)
        else:
            index_tensor = self.create_transform_tensor(index_arr)
            graph_converter.add_operator(tfl.PackOperator(new_indices, [index_tensor], dim, axis=1))

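        # broadcast / tile the values so there is exactly one value per index row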
        val_tensor = self.find_or_create_input(2, graph_converter)
        actual_val = val_tensor
        orig_val_shape = val_tensor.shape
        target_val_shape = index_tensor.shape[:-1]
        if orig_val_shape != target_val_shape:
            if val_tensor.buffer is None:
                new_shape = orig_val_shape
                val_reshaped = val_tensor
                if len(target_val_shape) > len(orig_val_shape):
                    new_shape = [1] * (len(target_val_shape) - len(orig_val_shape)) + list(orig_val_shape)
                    new_shape_arr = np.array(new_shape, dtype='int32')
                    new_shape_tensor = self.create_attr_tensor(new_shape_arr)
                    reshaped = self.create_transform_tensor(np.reshape(val_tensor.tensor, new_shape_arr))
                    val_reshaped = reshaped
                    reshape_op = tfl.ReshapeOperator([val_tensor, new_shape_tensor], [reshaped], new_shape_arr)
                    reshape_op.extra_hints['direction'] = 'up'
                    graph_converter.add_operator(reshape_op)

                repeats = []
                for x, y in zip(new_shape, target_val_shape):
                    if x != y:
                        repeats.append(y // x)
                    else:
                        repeats.append(1)

                actual_val = self.create_transform_tensor(np.tile(val_reshaped.tensor, repeats))
                repeat_tensor = self.create_attr_tensor(np.array(repeats, dtype='int32'))
                graph_converter.add_operator(tfl.TileOperator([val_reshaped, repeat_tensor], [actual_val]))
            else:
                actual_val = self.create_attr_tensor(np.broadcast_to(val_tensor.tensor, target_val_shape))

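        # static shape of the destination tensor, consumed by SCATTER_ND below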
        shape_tensor = self.create_attr_tensor(np.array(input_tensor.shape, dtype='int32'))

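        # read the current values at the target positions: GATHER_ND when the
        # input or the indices are dynamic, plain fancy indexing otherwise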
        if input_tensor.buffer is None or index_tensor.buffer is None:
            old_val_tensor = self.create_transform_tensor(actual_val.tensor)
            graph_converter.add_operator(tfl.GatherNdOperator([input_tensor, index_tensor], [old_val_tensor]))
        else:
            transformed_index = tuple(index_tensor.tensor[..., i] for i in range(index_tensor.shape[-1]))
            old_val_tensor = self.create_attr_tensor(input_tensor.tensor[transformed_index])

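        # the delta `values - old` must be computed at runtime when either
        # operand is dynamic; only the all-constant case can be folded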
        if actual_val.buffer is None or old_val_tensor.buffer is None:
            update_tensor = self.create_transform_tensor(actual_val.tensor - old_val_tensor.tensor)
            graph_converter.add_operator(tfl.SubOperator([actual_val, old_val_tensor], [update_tensor]))
        else:
            update_tensor = self.create_attr_tensor(actual_val.tensor - old_val_tensor.tensor)

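        # scatter the delta into a zero tensor of the input's shape, then add
        # it back: every indexed position becomes input + (value - old) == value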
        updated_tensor = self.create_transform_tensor(input_tensor.tensor)
        graph_converter.add_operator(
            tfl.ScatterNdOperator([index_tensor, update_tensor, shape_tensor], [updated_tensor])
        )

        graph_converter.add_operator(tfl.AddOperator([input_tensor, updated_tensor], [output_tensor]))
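
Taken together, the tail of the method relies on the identity
`index_put(x, idx, v) == x + scatter_nd(idx, v - x[idx], x.shape)`, which holds
when the index rows are unique. A minimal NumPy sketch of the same arithmetic
(the names below are illustrative, not converter API):

    import numpy as np

    x = np.arange(6.0).reshape(2, 3)      # the tensor being written to
    nd_idx = np.array([[0, 1], [1, 2]])   # (N, dim) rows, one per written element
    vals = np.array([10.0, 20.0])         # one value per index row

    coords = tuple(nd_idx[:, i] for i in range(nd_idx.shape[1]))
    old = x[coords]                       # what GATHER_ND reads at runtime
    delta = np.zeros_like(x)              # SCATTER_ND writes into zeros
    delta[coords] = vals - old
    out = x + delta                       # the final ADD

    expected = x.copy()
    expected[0, 1], expected[1, 2] = 10.0, 20.0
    assert np.array_equal(out, expected)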