tinynn/converter/operators/torch/aten.py [3221:3261]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
        dim, index = self.input_tensors[1:3]
        # Normalize a negative gather dimension
        if dim < 0:
            dim += input_tensor.tensor.ndim

        # Gather on a tensor of flat element offsets to find which positions of the input are selected
        fake_input = torch.arange(input_tensor.tensor.size).reshape(input_tensor.shape)
        fake_output = torch.gather(fake_input, dim, index)

        # Expand the selected flat offsets into full N-d coordinates (GATHER_ND-style index layout)
        indices = torch.nonzero(fake_input >= 0)[fake_output].to(dtype=torch.int32)

        # Cast the index back to `orig_type` (captured earlier in this method)
        self.input_tensors[2] = self.input_tensors[2].to(dtype=orig_type)
        index_tensor = self.find_or_create_input(2, graph_converter)
        if index_tensor.buffer is None:
            # Dynamic (non-constant) index: assemble the coordinate tensor in the graph.
            # Every coordinate column except the one along `dim` is a constant.
            indices_per_dim = torch.split(indices, 1, dim=-1)
            indices_tensors = [self.create_attr_tensor(t) for t in indices_per_dim]

            # Reshape the index to carry a trailing axis so it can act as a coordinate column
            index_shape = list(index_tensor.shape) + [1]
            axis = len(index_shape) - 1
            shape_tensor = self.create_attr_tensor(np.array(index_shape, dtype='int32'))
            index_reshaped = self.create_transform_tensor(np.reshape(index_tensor.tensor, index_shape))
            reshape_op = tfl.ReshapeOperator([index_tensor, shape_tensor], [index_reshaped], index_shape)
            reshape_op.extra_hints['direction'] = 'up'
            graph_converter.add_operator(reshape_op)

            # Cast the index column to int32 if needed
            if str(index_reshaped.dtype) != 'int32':
                index_casted = self.create_transform_tensor(index_reshaped.tensor.astype('int32'))
                graph_converter.add_operator(
                    tfl.CastOperator(
                        [index_reshaped],
                        [index_casted],
                        tfl.numpy_tflite_dtype_mappings[str(index_reshaped.dtype)],
                        tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)],
                    )
                )
                index_reshaped = index_casted

            # Substitute the runtime index for the column along `dim` and concatenate the
            # columns along the last axis to form the final coordinate tensor
            indices_tensors[dim] = index_reshaped
            indices_tensor = self.create_transform_tensor(np.concatenate([x.tensor for x in indices_tensors], axis=-1))
            graph_converter.add_operator(tfl.ConcatenationOperator(indices_tensors, [indices_tensor], axis=axis))
        else:
            # Constant index: the full coordinate tensor can be precomputed
            indices_tensor = self.create_attr_tensor(indices)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
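
A minimal standalone sketch (illustrative only, not part of the converter) of the index
construction above: the flat offsets selected by torch.gather are expanded into full N-d
coordinates, which reproduce the gather result when used for coordinate (GATHER_ND-style)
indexing. The shapes and index values below are made up.

import numpy as np
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
dim = 1
index = torch.tensor([[2, 0], [1, 3], [0, 0]])

# Same construction as in the converter code above
fake_input = torch.arange(x.numel()).reshape(x.shape)
fake_output = torch.gather(fake_input, dim, index)
indices = torch.nonzero(fake_input >= 0)[fake_output]  # shape: index.shape + (ndim,)

# Emulate GATHER_ND with numpy advanced indexing and compare against torch.gather
coords = indices.numpy()
gathered = x.numpy()[tuple(coords[..., i] for i in range(coords.shape[-1]))]
assert np.array_equal(gathered, torch.gather(x, dim, index).numpy())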



tinynn/converter/operators/torch/aten.py [3283:3323]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
        dim, index = self.input_tensors[1:3]
        # Normalize a negative gather dimension
        if dim < 0:
            dim += input_tensor.tensor.ndim

        # Gather on a tensor of flat element offsets to find which positions of the input are selected
        fake_input = torch.arange(input_tensor.tensor.size).reshape(input_tensor.shape)
        fake_output = torch.gather(fake_input, dim, index)

        # Expand the selected flat offsets into full N-d coordinates (GATHER_ND-style index layout)
        indices = torch.nonzero(fake_input >= 0)[fake_output].to(dtype=torch.int32)

        # Cast the index back to `orig_type` (captured earlier in this method)
        self.input_tensors[2] = self.input_tensors[2].to(dtype=orig_type)
        index_tensor = self.find_or_create_input(2, graph_converter)
        if index_tensor.buffer is None:
            # Dynamic (non-constant) index: assemble the coordinate tensor in the graph.
            # Every coordinate column except the one along `dim` is a constant.
            indices_per_dim = torch.split(indices, 1, dim=-1)
            indices_tensors = [self.create_attr_tensor(t) for t in indices_per_dim]

            # Reshape the index to carry a trailing axis so it can act as a coordinate column
            index_shape = list(index_tensor.shape) + [1]
            axis = len(index_shape) - 1
            shape_tensor = self.create_attr_tensor(np.array(index_shape, dtype='int32'))
            index_reshaped = self.create_transform_tensor(np.reshape(index_tensor.tensor, index_shape))
            reshape_op = tfl.ReshapeOperator([index_tensor, shape_tensor], [index_reshaped], index_shape)
            reshape_op.extra_hints['direction'] = 'up'
            graph_converter.add_operator(reshape_op)

            # Cast the index column to int32 if needed
            if str(index_reshaped.dtype) != 'int32':
                index_casted = self.create_transform_tensor(index_reshaped.tensor.astype('int32'))
                graph_converter.add_operator(
                    tfl.CastOperator(
                        [index_reshaped],
                        [index_casted],
                        tfl.numpy_tflite_dtype_mappings[str(index_reshaped.dtype)],
                        tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)],
                    )
                )
                index_reshaped = index_casted

            # Substitute the runtime index for the column along `dim` and concatenate the
            # columns along the last axis to form the final coordinate tensor
            indices_tensors[dim] = index_reshaped
            indices_tensor = self.create_transform_tensor(np.concatenate([x.tensor for x in indices_tensors], axis=-1))
            graph_converter.add_operator(tfl.ConcatenationOperator(indices_tensors, [indices_tensor], axis=axis))
        else:
            # Constant index: the full coordinate tensor can be precomputed
            indices_tensor = self.create_attr_tensor(indices)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
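
A minimal numpy/torch sketch (illustrative only) of the dynamic-index branch shared by both
occurrences above: every coordinate column except the one along `dim` is a precomputed
constant, and the `dim` column is just the runtime index reshaped to carry a trailing axis
and cast to int32, so concatenating the columns along the last axis rebuilds the full
coordinate tensor. The shapes and index values below are made up.

import numpy as np
import torch

x_shape = (3, 4)
dim = 1
index = torch.tensor([[2, 0], [1, 3], [0, 0]])  # stands in for the runtime index

fake_input = torch.arange(int(np.prod(x_shape))).reshape(x_shape)
fake_output = torch.gather(fake_input, dim, index)
indices = torch.nonzero(fake_input >= 0)[fake_output].to(dtype=torch.int32)

# Constant coordinate columns, one per input dimension
indices_per_dim = [t.numpy() for t in torch.split(indices, 1, dim=-1)]

# Mirror of the ReshapeOperator / CastOperator emitted in the graph
index_reshaped = index.numpy().reshape(list(index.shape) + [1]).astype('int32')
indices_per_dim[dim] = index_reshaped

gather_nd_indices = np.concatenate(indices_per_dim, axis=-1)
assert np.array_equal(gather_nd_indices, indices.numpy())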



