xla::XlaOp CreateIndexUpdate()

in torch_xla/csrc/xla_lower_util.cpp [549:605]


// Builds an xla::Scatter that merges `values` into `buffer` at the positions
// described by `indices`, with the indexed dimensions of `buffer` starting at
// `start_dim`.
//
// Parameters:
//   buffer    - operand that receives the update.
//   indices   - index tensor whose minor (last) dimension holds, for each
//               entry of its leading dimensions, a tuple of num_index_dims
//               coordinates into buffer dims
//               [start_dim, start_dim + num_index_dims).
//   start_dim - first buffer dimension addressed by the index tuples; must be
//               >= 0 and leave room for num_index_dims dimensions.
//   values    - update values. They are converted to the buffer element type
//               when the types differ, then passed through BuildExpand
//               (presumably a broadcast) to the shape
//                 buffer[:start_dim] ++ indices[:-1] ++
//                 buffer[start_dim + num_index_dims:]
//   combiner  - binary function used to build the scatter combiner
//               computation (via MakeScatterComputation); presumably called
//               as (current buffer element, incoming value) -> result.
//
// Returns the scattered result op. Fails an XLA_CHECK if `indices` is rank 0
// or if the indexed range does not fit inside the buffer rank.
xla::XlaOp CreateIndexUpdate(
    xla::XlaOp buffer, xla::XlaOp indices, xla::int64_t start_dim,
    xla::XlaOp values,
    const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp)>& combiner) {
  const xla::Shape& buffer_shape = XlaHelpers::ShapeOfXlaOp(buffer);
  const xla::Shape& indices_shape = XlaHelpers::ShapeOfXlaOp(indices);
  const xla::Shape& values_shape = XlaHelpers::ShapeOfXlaOp(values);

  absl::Span<const xla::int64_t> indices_dims =
      xla::AsInt64Slice(indices_shape.dimensions());
  XLA_CHECK(!indices_dims.empty());
  // The minor dimension of indices contains the indices to update.
  xla::int64_t num_index_dims = indices_dims.back();
  // What remains are the "batch" dimensions of the index tensor.
  indices_dims.remove_suffix(1);

  xla::int64_t buffer_rank = buffer_shape.rank();
  // The indexed buffer dimensions must lie inside the buffer. Without this
  // check the window-dimension arithmetic below silently produces invalid
  // dimension numbers and the error only surfaces later inside Scatter.
  XLA_CHECK(start_dim >= 0 && start_dim + num_index_dims <= buffer_rank)
      << "Invalid index update: start_dim=" << start_dim
      << ", num_index_dims=" << num_index_dims
      << ", buffer_rank=" << buffer_rank;
  // Buffer dimensions that are NOT addressed by the index tuples; each of them
  // becomes a window dimension of the scatter update.
  xla::int64_t num_window_dims_in_values = buffer_rank - num_index_dims;

  // Make the values match the rank expected by scatter:
  //   buffer[:start_dim] ++ indices[:-1] ++ buffer[start_dim + num_index_dims:]
  std::vector<xla::int64_t> expected_values_dims;
  expected_values_dims.reserve(
      static_cast<std::size_t>(num_window_dims_in_values) +
      indices_dims.size());
  for (xla::int64_t dim = 0; dim < start_dim; ++dim) {
    expected_values_dims.push_back(buffer_shape.dimensions(dim));
  }
  expected_values_dims.insert(expected_values_dims.end(), indices_dims.begin(),
                              indices_dims.end());
  for (xla::int64_t dim = num_index_dims + start_dim; dim < buffer_rank;
       ++dim) {
    expected_values_dims.push_back(buffer_shape.dimensions(dim));
  }

  xla::XlaOp new_values = values;
  if (buffer_shape.element_type() != values_shape.element_type()) {
    new_values = ConvertTo(new_values, values_shape.element_type(),
                           buffer_shape.element_type(), /*device=*/nullptr);
  }
  new_values = BuildExpand(new_values, expected_values_dims);
  xla::int64_t values_rank = XlaHelpers::ShapeOfXlaOp(new_values).rank();

  xla::ScatterDimensionNumbers dim_numbers;
  dim_numbers.set_index_vector_dim(indices_shape.rank() - 1);
  // Update window dims: the leading buffer dims before start_dim occupy the
  // first start_dim positions of the values ...
  for (xla::int64_t dim = 0; dim < start_dim; ++dim) {
    dim_numbers.add_update_window_dims(dim);
  }
  // ... and the buffer dims after the indexed range occupy the trailing
  // (num_window_dims_in_values - start_dim) positions of the values.
  for (xla::int64_t i = values_rank - num_window_dims_in_values + start_dim;
       i < values_rank; ++i) {
    dim_numbers.add_update_window_dims(i);
  }
  // The indexed buffer dims are collapsed out of the update window and are
  // addressed, in order, by the components of each index tuple.
  for (xla::int64_t i = 0; i < num_index_dims; ++i) {
    dim_numbers.add_inserted_window_dims(i + start_dim);
    dim_numbers.add_scatter_dims_to_operand_dims(i + start_dim);
  }
  xla::XlaComputation combiner_computation =
      MakeScatterComputation(combiner, buffer_shape.element_type());
  return xla::Scatter(buffer, indices, new_values, combiner_computation,
                      dim_numbers);
}