in lib/Dialect/mhlo/IR/hlo_ops.cc [5439:5543]
// Constant-folds a scatter whose base (operand), scatter indices and updates
// are all dense constant attributes of ranked tensor type.
//
// Every element of the update tensor is visited in row-major order. For each
// one the target position in the base tensor is derived from the matching
// index vector in the indices tensor plus the element's offset inside its
// update window, and the stored value is replaced by the result of evaluating
// the update computation region on (current value, update value).
//
// Returns an empty OpFoldResult (no fold) when:
//   * any of the three operands is not a dense constant / ranked tensor,
//   * any computed position lies outside the base tensor, or
//   * evaluateMhloRegion does not yield exactly one non-null value.
// Otherwise returns a DenseElementsAttr of the base type.
OpFoldResult ScatterOp::fold(ArrayRef<Attribute> operands) {
  auto base = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto index = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
  auto update = operands[2].dyn_cast_or_null<DenseElementsAttr>();
  if (!base || !index || !update) return {};

  auto base_type = base.getType().dyn_cast<RankedTensorType>();
  auto index_type = index.getType().dyn_cast<RankedTensorType>();
  auto update_type = update.getType().dyn_cast<RankedTensorType>();
  if (!base_type || !index_type || !update_type) return {};

  // With no update elements there is nothing to scatter, so the result is the
  // base tensor itself. This must be handled up front: the do/while loop below
  // always executes its body once and would otherwise read element
  // (0, ..., 0) of an empty update (and possibly empty indices) tensor.
  if (update_type.getNumElements() == 0) return base;

  const auto dimension_numbers = scatter_dimension_numbers();
  const int64_t index_vector_dim = dimension_numbers.getIndexVectorDim();
  const auto update_window_dims = dimension_numbers.getUpdateWindowDims();
  const auto inserted_window_dims = dimension_numbers.getInsertedWindowDims();
  const auto scatter_dims_to_operand_dims =
      dimension_numbers.getScatterDimsToOperandDims();

  // Add the virtual trailing dimension of size 1 if index_vector_dim equals
  // the rank of the indices tensor.
  if (index_vector_dim == index_type.getRank()) {
    auto index_shape = index_type.getShape().vec();
    index_shape.push_back(1);
    index_type =
        RankedTensorType::get(index_shape, index_type.getElementType());
    index = index.reshape(index_type).cast<DenseIntElementsAttr>();
  }

  // Advances the multi-dimensional index (last dimension fastest) within the
  // limits given by shape. Returns true while a new valid index was produced
  // and false once the index rolled around to all zeros.
  auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
                       llvm::ArrayRef<int64_t> shape) {
    for (int64_t i = static_cast<int64_t>(index.size()) - 1; i >= 0; --i) {
      ++index[i];
      if (index[i] < static_cast<uint64_t>(shape[i])) return true;
      index[i] = 0;
    }
    return false;
  };

  // Values that do not change between update elements.
  const int64_t base_rank = base_type.getRank();
  const int64_t update_rank = update_type.getRank();
  const llvm::ArrayRef<int64_t> base_shape = base_type.getShape();
  // Number of components in one index vector.
  const uint64_t index_count = index_type.getShape()[index_vector_dim];
  // Rank-0 tensor type used to wrap single elements for the region evaluator.
  const auto scalar_type =
      RankedTensorType::get({}, base_type.getElementType());
  auto index_values = index.getValues<APInt>();
  auto update_values = update.getValues<Attribute>();

  // Iterate over all elements of the update tensor, then find the
  // corresponding value in the indices tensor to determine which location we
  // have to update in the base/result tensor.
  llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>());
  llvm::SmallVector<uint64_t, 8> update_index(update_rank, 0);
  llvm::SmallVector<uint64_t, 8> index_index;
  index_index.reserve(index_type.getRank());
  llvm::SmallVector<uint64_t, 8> base_index;
  base_index.reserve(base_rank);
  do {
    // Compute the index for the slice of the indices tensor for this update
    // value: the non-window coordinates of the update element, with a
    // placeholder 0 inserted at index_vector_dim.
    index_index.clear();
    if (index_vector_dim == 0) index_index.push_back(0);
    for (int64_t i = 0; i < update_rank; ++i) {
      if (llvm::count(update_window_dims, i) == 0)
        index_index.push_back(update_index[i]);
      if (static_cast<int64_t>(index_index.size()) == index_vector_dim)
        index_index.push_back(0);
    }

    // Compute the index for the given update value in the base tensor. The
    // arithmetic is deliberately unsigned: a negative scatter index wraps to a
    // very large value, which the bounds check below rejects.
    base_index.assign(base_rank, 0);
    for (uint64_t i = 0; i < index_count; ++i) {
      const uint64_t operand_dim = scatter_dims_to_operand_dims[i];
      index_index[index_vector_dim] = i;
      base_index[operand_dim] += index_values[index_index].getSExtValue();
    }
    // Add the offset inside the update window for every base dimension that
    // is not an inserted window dimension.
    uint64_t update_window_dim_index = 0;
    for (int64_t i = 0; i < base_rank; ++i) {
      if (llvm::count(inserted_window_dims, i)) continue;
      base_index[i] +=
          update_index[update_window_dims[update_window_dim_index]];
      update_window_dim_index++;
    }

    // Compute the row-major linear index into the base tensor.
    uint64_t linear_base_index = 0;
    uint64_t stride = 1;
    for (int64_t i = base_rank - 1; i >= 0; --i) {
      const uint64_t dim_size = static_cast<uint64_t>(base_shape[i]);
      // Out of bound index have backend specific behaviour so avoid folding
      // it. Negative positions show up here as huge unsigned values.
      if (base_index[i] >= dim_size) return {};
      linear_base_index += base_index[i] * stride;
      stride *= dim_size;
    }

    // Evaluate update computation and update the value with the newly computed
    // attribute in the base tensor.
    auto lhs = DenseElementsAttr::get(scalar_type, results[linear_base_index]);
    auto rhs = DenseElementsAttr::get(scalar_type, update_values[update_index]);
    auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs});
    if (new_value.size() != 1 || !new_value[0]) return {};
    results[linear_base_index] =
        new_value[0].cast<DenseElementsAttr>().getValues<Attribute>()[0];
  } while (next_index(update_index, update_type.getShape()));

  return DenseElementsAttr::get(base_type, results);
}