def parse()

in tinynn/converter/operators/torch/aten.py [0:0]


    def parse(self, node, attrs, args, graph_converter):
        """Convert an advanced-indexing node into TFLite gather operators.

        ``self.input_tensors[1]`` holds one entry per indexed input dimension;
        ``None`` entries mark dimensions that are not indexed. Every non-``None``
        index must be an int32/int64 tensor.

        * More than one non-``None`` index: the index tensors are broadcast
          against each other, every coordinate of the trailing (unindexed)
          input dims is appended, and the resulting coordinates are fed to a
          single ``GATHER_ND``.
        * Otherwise: each non-``None`` index becomes a ``GATHER`` along its
          dimension, the last one writing straight into the node output.

        Args:
            node: the node being converted; forwarded to the base-class
                ``parse`` and to ``self.run``.
            attrs: forwarded to the base-class ``parse``.
            args: forwarded to the base-class ``parse``.
            graph_converter: receives the created TFLite operators.
        """
        super().parse(node, attrs, args, graph_converter)

        self.run(node)
        indices = self.input_tensors[1]

        filtered_dims = [i for i, idx in enumerate(indices) if idx is not None]
        assert all(indices[i].dtype in (torch.int64, torch.int32) for i in filtered_dims)

        input_tensor = self.find_or_create_input(0, graph_converter)
        outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)

        if len(filtered_dims) > 1:
            # NOTE(review): below, indices_tensors[i] is treated as the index for
            # input dim i and only the dims after the last index are enumerated,
            # so this path appears to assume no `None` precedes or sits between
            # the indices — TODO confirm.
            if graph_converter.has_nested_names(self.input_names[1]):
                input_names = graph_converter.get_list_expanded_names(self.input_names[1])
                indices_tensors = self.to_tfl_tensors(
                    input_names, self.input_tensors[1], graph_converter=graph_converter, non_existent_as_buffer=True
                )
            elif isinstance(self.input_tensors[1], (tuple, list)):
                indices_tensors = [self.create_attr_tensor(x) for x in self.input_tensors[1]]
            else:
                indices_tensors = [self.find_or_create_input(1, graph_converter)]

            dim = input_tensor.tensor.ndim
            num_indices = len(indices_tensors)

            # For every element of the broadcast index shape, compute the flat
            # position to read from each index tensor. Broadcasting per-tensor
            # position grids supports index tensors of different but compatible
            # shapes (e.g. `x[rows[:, None], cols]`); for equal-size or size-1
            # indices it yields the same positions as `arange(n) % size`.
            positions = torch.broadcast_tensors(
                *[torch.arange(x.tensor.size).reshape(x.tensor.shape) for x in indices_tensors]
            )
            left_indices = torch.stack([p.reshape(-1) for p in positions], 1).int()
            all_indices_shape = list(outputs[0].shape) + [dim]

            if num_indices < dim:
                # Enumerate every coordinate of the trailing, unindexed dims and
                # pair each with every broadcast index row (index rows outermost).
                pad_shape = list(input_tensor.shape[num_indices:])
                pad_indices = torch.ones(pad_shape).nonzero().int()
                left_len = num_indices
                right_len = len(pad_shape)
                left_size = left_indices.size(0)
                right_size = pad_indices.size(0)
                left_reshaped = (
                    left_indices.view(-1, 1, left_len).expand(-1, right_size, left_len).reshape(-1, left_len)
                )
                right_reshaped = (
                    pad_indices.view(1, -1, right_len).expand(left_size, -1, right_len).reshape(-1, right_len)
                )
                all_indices = torch.cat([left_reshaped, right_reshaped], 1).view(all_indices_shape).unbind(-1)
            else:
                all_indices = left_indices.view(all_indices_shape).unbind(-1)

            # Build one int32 coordinate tensor per input dim.
            new_indices = []
            for i in range(dim):
                flat_pos = all_indices[i].numpy()

                if i >= num_indices:
                    # Trailing dim: coordinates are compile-time constants.
                    new_indices.append(self.create_attr_tensor(flat_pos.astype(np.int32)))
                    continue

                idx_tensor = indices_tensors[i]
                actual_idx = np.take(idx_tensor.tensor, flat_pos)

                if idx_tensor.buffer is not None:
                    # Constant index: resolve it now. Negative indices are wrapped
                    # as in the single-index path below; presumably GATHER_ND does
                    # not accept negative coordinates — confirm against the kernel.
                    actual_idx = np.where(actual_idx < 0, actual_idx + input_tensor.shape[i], actual_idx)
                    new_indices.append(self.create_attr_tensor(actual_idx.astype(np.int32)))
                    continue

                # Runtime index: gather its values at the broadcast positions.
                # NOTE(review): GATHER on axis 0 with flat positions only matches
                # `np.take` above for 1-D index tensors — TODO confirm
                # multi-dimensional runtime indices are not expected here.
                actual_idx_t = self.create_transform_tensor(actual_idx)
                fake_idx_t = self.create_attr_tensor(flat_pos)
                graph_converter.add_operator(tfl.GatherOperator([idx_tensor, fake_idx_t], [actual_idx_t], axis=0))

                if str(actual_idx_t.dtype) != 'int32':
                    # Normalize to int32 so it can be packed with the other coordinates.
                    index_casted = self.create_transform_tensor(actual_idx_t.tensor.astype('int32'))
                    graph_converter.add_operator(
                        tfl.CastOperator(
                            [actual_idx_t],
                            [index_casted],
                            tfl.numpy_tflite_dtype_mappings[str(actual_idx_t.dtype)],
                            tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)],
                        )
                    )
                    actual_idx_t = index_casted
                new_indices.append(actual_idx_t)

            # Stack per-dim coordinates along a new last axis; fold into a
            # constant when every coordinate is known, else emit a PACK.
            index_arr = np.stack([x.tensor for x in new_indices], -1)
            if all(x.buffer is not None for x in new_indices):
                index_tensor = self.create_attr_tensor(index_arr)
            else:
                index_tensor = self.create_transform_tensor(index_arr)
                graph_converter.add_operator(
                    tfl.PackOperator(new_indices, [index_tensor], dim, axis=index_tensor.tensor.ndim - 1)
                )

            graph_converter.add_operator(tfl.GatherNdOperator([input_tensor, index_tensor], outputs))
        else:
            try:
                names = graph_converter.get_list_expanded_names(self.input_names[1])
            except KeyError:
                # No expanded names for the index list: use freshly generated
                # names, which `non_existent_as_buffer=True` below presumably
                # turns into constant buffers.
                names = [self.get_unique_attr_name() for _ in indices]

            filtered_names = [names[i] for i in filtered_dims]
            filtered_tensors = [indices[i].to(dtype=torch.int32) for i in filtered_dims]

            # Wrap negative indices, but only for index tensors that are not
            # already graph tensors (their values are known here).
            filtered_tensors = [
                t + (t < 0).int() * input_tensor.shape[i] if n not in graph_converter.tensor_map else t
                for i, n, t in zip(filtered_dims, filtered_names, filtered_tensors)
            ]
            indice_tensors = self.to_tfl_tensors(
                filtered_names, filtered_tensors, graph_converter=graph_converter, non_existent_as_buffer=True
            )

            # Chain one GATHER per indexed dim; only the last one targets the
            # real output, earlier ones produce intermediate tensors.
            actual_input = input_tensor
            for i, (axis, idx) in enumerate(zip(filtered_dims, indice_tensors)):
                if i == len(filtered_dims) - 1:
                    actual_output = outputs[0]
                else:
                    actual_output = self.create_transform_tensor(np.take(actual_input.tensor, idx.tensor, axis=axis))

                graph_converter.add_operator(tfl.GatherOperator([actual_input, idx], [actual_output], axis=axis))

                actual_input = actual_output