in Sources/x10/xla_tensor/xla_lower_util.cpp [209:276]
// Lowers a scatter along `dim` into dense XLA ops rather than xla::Scatter.
// 1. A one-hot mask is built over an "expanded" shape: the index shape with an
//    extra axis of size input_shape.dimensions(dim) inserted at `dim`.
// 2. The mask picks which `src` elements land on each input position.
// 3. Those values are reduced over the index's own `dim` axis.
// 4. The result is merged back into `input`.
//
// The intermediate mask / selected-source tensors have that expanded shape, so
// memory grows with (number of index elements) * input_shape.dimensions(dim).
//
// input:   tensor being scattered into.
// index:   target positions along `dim` (presumably same rank as `input` --
//          TODO confirm against callers).
// src:     values to scatter; broadcast with the same axis mapping as `index`.
// dim:     scatter dimension.
// options: optional combiner, optional init value, and a uniqueness hint.
// Returns the scattered tensor; errors surface through the builder.
//
// TODO: contribute this back to xla::TorchScatterDense() once the
// implementation has stabilized.
xla::XlaOp XlaDenseScatter(xla::XlaOp input, xla::XlaOp index, xla::XlaOp src,
                           xla::int64 dim, const ScatterOptions& options) {
  xla::XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    const xla::Shape& index_shape = XlaHelpers::ShapeOfXlaOp(index);
    const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
    const xla::PrimitiveType element_type = input_shape.element_type();

    // expanded_sizes: the index dimensions, with input_shape.dimensions(dim)
    // inserted at position `dim`. Index axes at or after `dim` therefore shift
    // up by one; broadcast_dims records where each index axis lands.
    std::vector<xla::int64> broadcast_dims;
    std::vector<xla::int64> expanded_sizes;
    for (xla::int64 axis = 0; axis < index_shape.rank(); ++axis) {
      if (axis == dim) {
        expanded_sizes.push_back(input_shape.dimensions(axis));
      }
      broadcast_dims.push_back(axis < dim ? axis : axis + 1);
      expanded_sizes.push_back(index_shape.dimensions(axis));
    }

    xla::XlaOp init_value = options.init_value
                                ? *options.init_value
                                : xla::Zero(builder, element_type);

    // With a combiner, reduce using the computation derived from it.
    // Otherwise fall back to XLA's identity-with-zero reducer.
    const xla::XlaComputation reducer =
        options.combiner == nullptr
            ? xla::CreateScalarIdentityWithZeroComputation(element_type,
                                                           builder)
            : MakeScatterComputation(options.combiner, element_type);

    // One-hot mask over the expanded shape: true where the broadcast index
    // value equals the iota coordinate along `dim`.
    xla::XlaOp expanded_index =
        xla::BroadcastInDim(index, expanded_sizes, broadcast_dims);
    xla::XlaOp positions = xla::Iota(
        builder,
        xla::ShapeUtil::MakeShape(index_shape.element_type(), expanded_sizes),
        dim);
    xla::XlaOp mask = xla::Eq(expanded_index, positions);

    // Keep `src` where the mask is set and init_value elsewhere. Then collapse
    // the index's original `dim` axis, which now sits at `dim + 1`.
    xla::XlaOp selected_src = xla::Select(
        mask, xla::BroadcastInDim(src, expanded_sizes, broadcast_dims),
        xla::Broadcast(init_value, expanded_sizes));
    xla::XlaOp masked_src =
        xla::Reduce(selected_src, init_value, reducer, {dim + 1});

    // Merges scattered values with `input` through the combiner, if one is set.
    auto combine_with_input = [&](xla::XlaOp updates) -> xla::XlaOp {
      return options.combiner != nullptr ? options.combiner(input, updates)
                                         : updates;
    };

    if (options.indices_are_unique &&
        XlaHelpers::SameStaticDimensions(index_shape, input_shape)) {
      // Unique indices over an index shaped like the input cover every input
      // position, so no coverage mask is needed.
      return combine_with_input(masked_src);
    }

    // Per-position flag: did any index element target this position?
    // Computed as a logical OR over the index's `dim` axis.
    xla::XlaOp reduced_mask = xla::Reduce(
        mask, xla::ConstantR0<bool>(builder, false),
        xla::CreateScalarOrComputation(xla::PrimitiveType::PRED, builder),
        {dim + 1});
    if (ScatterRequiresPadding(input_shape, index_shape, dim)) {
      // Bring both reduced tensors up to the input dimensions. The mask uses
      // PadToSize's default pad value (assumed to be false -- confirm).
      masked_src = PadToSize(masked_src, input_shape.dimensions(), init_value);
      reduced_mask = PadToSize(reduced_mask, input_shape.dimensions());
    }
    // Positions no index touched keep their original `input` value.
    return xla::Select(reduced_mask, combine_with_input(masked_src), input);
  });
}