in torch_xla/csrc/xla_lower_util.cpp [549:605]
// Builds an xla::Scatter op that writes `values` into `buffer` at the
// locations given by `indices`.
//
// - The minor-most dimension of `indices` is the index vector; its size is the
//   number of `buffer` dimensions being indexed, starting at `start_dim`.
// - `values` is converted to the element type of `buffer` when the two differ,
//   then passed through BuildExpand with the target dimensions
//     buffer dims [0, start_dim)
//     ++ all dims of `indices` except the last
//     ++ buffer dims [start_dim + index vector size, buffer rank).
// - `combiner` is turned into the scatter computation by
//   MakeScatterComputation (using the element type of `buffer`).
//
// Fails an XLA_CHECK if `indices` has no dimensions.
xla::XlaOp CreateIndexUpdate(
    xla::XlaOp buffer, xla::XlaOp indices, xla::int64_t start_dim,
    xla::XlaOp values,
    const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp)>& combiner) {
  const xla::Shape& buffer_shape = XlaHelpers::ShapeOfXlaOp(buffer);
  const xla::Shape& indices_shape = XlaHelpers::ShapeOfXlaOp(indices);
  const xla::Shape& values_shape = XlaHelpers::ShapeOfXlaOp(values);

  // Split the shape of `indices` into its leading dimensions and the size of
  // the trailing index vector.
  absl::Span<const xla::int64_t> leading_index_dims =
      xla::AsInt64Slice(indices_shape.dimensions());
  XLA_CHECK(!leading_index_dims.empty());
  const xla::int64_t index_vector_size = leading_index_dims.back();
  leading_index_dims.remove_suffix(1);

  const xla::int64_t buffer_rank = buffer_shape.rank();
  // Buffer dimensions that are not addressed by the index vector.
  const xla::int64_t window_dims_in_updates = buffer_rank - index_vector_size;
  // First buffer dimension located after the indexed range.
  const xla::int64_t trailing_begin = index_vector_size + start_dim;

  // Dimensions the updates must be expanded to before scattering.
  std::vector<xla::int64_t> update_dims;
  for (xla::int64_t d = 0; d < start_dim; ++d) {
    update_dims.push_back(buffer_shape.dimensions(d));
  }
  for (const xla::int64_t index_dim : leading_index_dims) {
    update_dims.push_back(index_dim);
  }
  for (xla::int64_t d = trailing_begin; d < buffer_rank; ++d) {
    update_dims.push_back(buffer_shape.dimensions(d));
  }

  // Bring the updates to the buffer's element type, then to the expected
  // dimensions.
  const bool needs_conversion =
      buffer_shape.element_type() != values_shape.element_type();
  xla::XlaOp typed_values =
      needs_conversion
          ? ConvertTo(values, values_shape.element_type(),
                      buffer_shape.element_type(), /*device=*/nullptr)
          : values;
  xla::XlaOp updates = BuildExpand(typed_values, update_dims);
  const xla::int64_t updates_rank = XlaHelpers::ShapeOfXlaOp(updates).rank();

  xla::ScatterDimensionNumbers scatter_dims;
  scatter_dims.set_index_vector_dim(indices_shape.rank() - 1);

  // Window dimensions of the updates: the leading `start_dim` dimensions
  // first, then the trailing ones that follow the dimensions coming from
  // `indices`.
  for (xla::int64_t d = 0; d < start_dim; ++d) {
    scatter_dims.add_update_window_dims(d);
  }
  const xla::int64_t first_trailing_window_dim =
      updates_rank - window_dims_in_updates + start_dim;
  for (xla::int64_t d = first_trailing_window_dim; d < updates_rank; ++d) {
    scatter_dims.add_update_window_dims(d);
  }

  // Each index vector component addresses one buffer dimension, starting at
  // `start_dim`.
  for (xla::int64_t operand_dim = start_dim; operand_dim < trailing_begin;
       ++operand_dim) {
    scatter_dims.add_inserted_window_dims(operand_dim);
    scatter_dims.add_scatter_dims_to_operand_dims(operand_dim);
  }

  xla::XlaComputation scatter_computation =
      MakeScatterComputation(combiner, buffer_shape.element_type());
  return xla::Scatter(buffer, indices, updates, scatter_computation,
                      scatter_dims);
}