in lib/Dialect/lhlo/transforms/lhlo_legalize_to_affine.cc [244:464]
// Lowers a statically shaped GatherOp to a nest of affine.for loops.
//
// The output buffer is first filled with zero. The generated loop nest then
// iterates over (a) every batch position of `start_indices`, (b) every offset
// dimension of the output and (c) every operand dimension listed in
// `start_index_map`. In the innermost body the operand element is loaded and
// accumulated into the output only if, for each indexed operand dimension,
// `start_index + offset == operand_index`; zero is accumulated otherwise.
//
// Returns a match failure (without touching the IR) when any of the involved
// memrefs is dynamically shaped or when the dimension numbers would make the
// lowering index out of bounds.
LogicalResult matchAndRewrite(GatherOp op,
                              PatternRewriter& rewriter) const final {
  Location loc = op.getLoc();

  // Operand array.
  Value operand = op.operand();
  MemRefType operand_type = operand.getType().cast<MemRefType>();
  const int64_t operand_rank = operand_type.getRank();
  ArrayRef<int64_t> operand_shape = operand_type.getShape();

  // Start_indices array.
  Value start_indices = op.start_indices();
  MemRefType start_indices_type = start_indices.getType().cast<MemRefType>();
  const int64_t start_indices_rank = start_indices_type.getRank();
  ArrayRef<int64_t> start_indices_shape = start_indices_type.getShape();

  // Output array.
  Value output = op.output();
  MemRefType output_type = output.getType().cast<MemRefType>();
  ArrayRef<int64_t> output_shape = output_type.getShape();
  const int64_t output_rank = output_type.getRank();

  if (!operand_type.hasStaticShape() ||
      !start_indices_type.hasStaticShape() || !output_type.hasStaticShape())
    return rewriter.notifyMatchFailure(op, "only static shaped type allowed");

  mhlo::GatherDimensionNumbersAttr gather_dim = op.dimension_numbers();
  auto collapsed_slice_dims = gather_dim.getCollapsedSliceDims();
  auto offset_dims = gather_dim.getOffsetDims();
  auto start_index_map = gather_dim.getStartIndexMap();
  const int64_t index_vector_dim = gather_dim.getIndexVectorDim();

  // Small membership helper; the dimension lists are tiny so a linear scan is
  // adequate and avoids relying on the lists being sorted.
  auto contains = [](ArrayRef<int64_t> dims, int64_t dim) {
    for (int64_t d : dims)
      if (d == dim) return true;
    return false;
  };

  // ---- Validation. Everything below must run before any op is created so
  // that a match failure leaves the IR untouched. ----

  if (index_vector_dim < 0 || index_vector_dim > start_indices_rank)
    return rewriter.notifyMatchFailure(op, "index_vector_dim is out of range");

  // We fetch the shape of start_indices at index_vector_dim. In case
  // index_vector_dim is equal to the rank of start_indices, we implicitly
  // consider start_indices to have a trailing 1 dimension.
  const bool has_implicit_index_dim = (index_vector_dim == start_indices_rank);
  const int64_t start_indices_numbers =
      has_implicit_index_dim ? 1 : start_indices_shape[index_vector_dim];

  // Every component of an index vector needs exactly one operand dimension to
  // map to. An empty map would also leave us without any predicate below.
  if (start_index_map.empty() ||
      static_cast<int64_t>(start_index_map.size()) != start_indices_numbers)
    return rewriter.notifyMatchFailure(
        op, "start_index_map size must match the index vector size");
  for (int64_t dim : start_index_map)
    if (dim < 0 || dim >= operand_rank)
      return rewriter.notifyMatchFailure(
          op, "start_index_map entry is out of operand range");

  // Number of batch loops, i.e. dimensions of start_indices other than
  // index_vector_dim.
  const int64_t batch_rank =
      start_indices_rank - (has_implicit_index_dim ? 0 : 1);

  // The i-th offset dimension of the output addresses the i-th operand
  // dimension that is not listed in collapsed_slice_dims. Precompute that
  // mapping and reject attribute combinations that would index out of bounds.
  SmallVector<int64_t, 4> offset_operand_dims;
  int64_t next_operand_dim = 0;
  for (size_t i = 0, e = offset_dims.size(); i < e; ++i) {
    const int64_t out_dim = offset_dims[i];
    // offset_dims is assumed sorted; the second condition guarantees the
    // induction variable can be inserted at position out_dim later on.
    if (out_dim < 0 || out_dim >= output_rank ||
        out_dim > batch_rank + static_cast<int64_t>(i))
      return rewriter.notifyMatchFailure(
          op, "offset_dims must be sorted and within the output rank");
    while (next_operand_dim < operand_rank &&
           contains(collapsed_slice_dims, next_operand_dim))
      ++next_operand_dim;
    if (next_operand_dim >= operand_rank)
      return rewriter.notifyMatchFailure(
          op, "more offset dims than non-collapsed operand dims");
    offset_operand_dims.push_back(next_operand_dim++);
  }

  // ---- Rewrite. No match failure may be returned past this point. ----

  // Creating constants with 0 value. We need the Integer type constant value
  // because the indices type will be Integer.
  Value zero_int_val = rewriter.create<mlir::arith::ConstantIntOp>(
      loc, 0, start_indices_type.getElementType());
  Type element_type = output_type.getElementType();
  Value zero_load_value = getZeroValue(element_type, loc, rewriter);
  // Initializing the output buffer with 0.
  fillBuffer(loc, output, zero_load_value, rewriter);

  // Index constants 0 .. start_indices_numbers-1, used to address the
  // components of an index vector inside start_indices.
  SmallVector<Value, 4> start_indices_index;
  for (int64_t i = 0; i < start_indices_numbers; ++i) {
    Value i_val = rewriter.create<mlir::arith::ConstantIntOp>(
        loc, i, start_indices_type.getElementType());
    i_val = rewriter.create<arith::IndexCastOp>(loc, i_val,
                                                rewriter.getIndexType());
    start_indices_index.push_back(i_val);
  }

  // S_in contains the multiple indices that form a starting index used in the
  // input/operand tensor. O_in contains the multiple offsets of corresponding
  // starting index used in the input/operand tensor. We initialize both of
  // them with 0.
  Value zero_index_val = rewriter.create<arith::IndexCastOp>(
      loc, zero_int_val, rewriter.getIndexType());
  SmallVector<Value, 4> S_in(operand_rank, zero_index_val);
  SmallVector<Value, 4> O_in(operand_rank, zero_index_val);

  // batch_induction_vars stores the loop induction variables pertaining to
  // the batches of start indices.
  SmallVector<Value, 4> batch_induction_vars;
  // output_induction_vars stores the loop induction variables pertaining to
  // both batches and offsets within the output tensor.
  SmallVector<Value, 4> output_induction_vars;

  // Create loops to iterate over each batch of starting index.
  for (int64_t i = 0; i < start_indices_rank; ++i) {
    // ith dimension of start_indices doesn't form a batch if it is equal to
    // index_vector_dim.
    if (i == index_vector_dim) continue;
    AffineForOp for_op =
        rewriter.create<AffineForOp>(loc, 0, start_indices_shape[i]);
    batch_induction_vars.push_back(for_op.getInductionVar());
    output_induction_vars.push_back(for_op.getInductionVar());
    rewriter.setInsertionPointToStart(for_op.getBody());
  }

  // Create loops to iterate over each offset dimension within the output
  // tensor.
  for (size_t i = 0, e = offset_dims.size(); i < e; ++i) {
    AffineForOp for_op =
        rewriter.create<AffineForOp>(loc, 0, output_shape[offset_dims[i]]);
    rewriter.setInsertionPointToStart(for_op.getBody());
    // Remapping offset_dims[i] to its non-collapsed operand dimension.
    O_in[offset_operand_dims[i]] = for_op.getInductionVar();
    // offset_dims is sorted (validated above), so inserting at offset_dims[i]
    // places the induction variable at its final output position.
    output_induction_vars.insert(
        output_induction_vars.begin() + offset_dims[i],
        for_op.getInductionVar());
  }

  // Create loops over the operand dimensions addressed through
  // start_index_map. All remaining operand dimensions are addressed directly
  // by their offset (or 0 for collapsed dimensions).
  SmallVector<Value, 4> operand_index;
  for (int64_t i = 0; i < operand_rank; ++i) {
    if (contains(start_index_map, i)) {
      AffineForOp for_op =
          rewriter.create<AffineForOp>(loc, 0, operand_shape[i]);
      operand_index.push_back(for_op.getInductionVar());
      rewriter.setInsertionPointToStart(for_op.getBody());
    } else {
      operand_index.push_back(O_in[i]);
    }
  }

  if (!has_implicit_index_dim) {
    // Load every component of the index vector by temporarily splicing the
    // component's constant index into the batch indices at index_vector_dim.
    for (int64_t i = 0; i < start_indices_numbers; ++i) {
      batch_induction_vars.insert(
          batch_induction_vars.begin() + index_vector_dim,
          start_indices_index[i]);
      Value start_index = rewriter.create<AffineLoadOp>(loc, start_indices,
                                                        batch_induction_vars);
      start_index = rewriter.create<arith::IndexCastOp>(
          loc, start_index, rewriter.getIndexType());
      S_in[start_index_map[i]] = start_index;
      batch_induction_vars.erase(batch_induction_vars.begin() +
                                 index_vector_dim);
    }
  } else {
    // Since index_vector_dim is equal to the rank of start_indices, each
    // batch position holds a single scalar start index. It belongs to the
    // (only) operand dimension named by start_index_map.
    Value start_index = rewriter.create<AffineLoadOp>(loc, start_indices,
                                                      batch_induction_vars);
    start_index = rewriter.create<arith::IndexCastOp>(
        loc, start_index, rewriter.getIndexType());
    S_in[start_index_map[0]] = start_index;
  }

  // We load value at a particular operand index and populate the output
  // tensor if the index constraints match.
  Value load_value =
      rewriter.create<AffineLoadOp>(loc, operand, operand_index);

  // Adding offsets to the corresponding starting index and comparing it with
  // the corresponding operand index.
  SmallVector<Value, 4> predicates;
  for (int64_t dim : start_index_map) {
    Value add_start_index_offset = rewriter.create<mlir::arith::AddIOp>(
        loc, rewriter.getIndexType(), S_in[dim], O_in[dim]);
    Value predicate = rewriter.create<mlir::arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, add_start_index_offset,
        operand_index[dim]);
    predicates.push_back(predicate);
  }

  // Conjoin all predicates. They are AND-ed pairwise and each pair is folded
  // into result_predicate; a trailing unpaired predicate is folded in last.
  // predicates is non-empty because start_index_map was validated above.
  Value result_predicate = nullptr;
  const size_t num_predicates = predicates.size();
  for (size_t i = 0; i + 1 < num_predicates; i += 2) {
    Value and_predicate = rewriter.create<mlir::arith::AndIOp>(
        loc, predicates[i], predicates[i + 1]);
    result_predicate = result_predicate
                           ? rewriter.create<mlir::arith::AndIOp>(
                                 loc, result_predicate, and_predicate)
                           : and_predicate;
  }
  if (num_predicates % 2 == 1) {
    Value last_predicate = predicates.back();
    result_predicate = result_predicate
                           ? rewriter.create<mlir::arith::AndIOp>(
                                 loc, result_predicate, last_predicate)
                           : last_predicate;
  }

  // We use the loaded value if the index computed by adding offsets to
  // starting index is equal to the current operand index. We use 0 as a value
  // otherwise.
  Value select_load = rewriter.create<mlir::SelectOp>(
      loc, result_predicate, load_value, zero_load_value);
  // We load value at output array.
  Value output_value =
      rewriter.create<AffineLoadOp>(loc, output, output_induction_vars);
  // The selected value is added to the previous value stored in output array.
  if (element_type.isa<FloatType>())
    output_value = rewriter.create<arith::AddFOp>(loc, element_type,
                                                  select_load, output_value);
  else
    output_value = rewriter.create<arith::AddIOp>(loc, element_type,
                                                  select_load, output_value);
  rewriter.create<AffineStoreOp>(loc, output_value, output,
                                 output_induction_vars);
  rewriter.eraseOp(op);
  return success();
}