LogicalResult matchAndRewrite()

in lib/Dialect/lhlo/transforms/lhlo_legalize_to_affine.cc [244:464]
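
This pattern lowers an lmhlo.gather to an affine loop nest. As a point of
reference, a minimal input (shapes and attribute values chosen purely for
illustration, not taken from the source) would look roughly like:

  "lmhlo.gather"(%operand, %start_indices, %output) {
    dimension_numbers = #mhlo.gather<
      offset_dims = [1], collapsed_slice_dims = [0],
      start_index_map = [0], index_vector_dim = 1>,
    slice_sizes = dense<[1, 4]> : tensor<2xi64>
  } : (memref<3x4xf32>, memref<2x1xi32>, memref<2x4xf32>) -> ()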


  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter& rewriter) const final {
    Location loc = op.getLoc();

    // Operand array.
    Value operand = op.operand();
    MemRefType operand_type = operand.getType().cast<MemRefType>();
    unsigned operand_rank = operand_type.getRank();
    ArrayRef<int64_t> operand_shape = operand_type.getShape();

    // Start_indices array.
    Value start_indices = op.start_indices();
    MemRefType start_indices_type = start_indices.getType().cast<MemRefType>();
    unsigned start_indices_rank = start_indices_type.getRank();
    ArrayRef<int64_t> start_indices_shape = start_indices_type.getShape();

    // Output array.
    Value output = op.output();
    MemRefType output_type = output.getType().cast<MemRefType>();
    ArrayRef<int64_t> output_shape = output_type.getShape();

    if (!operand_type.hasStaticShape() ||
        !start_indices_type.hasStaticShape() || !output_type.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "only static shaped types are allowed");

    mhlo::GatherDimensionNumbersAttr gather_dim = op.dimension_numbers();

    auto collapsed_slice_dims = gather_dim.getCollapsedSliceDims();
    auto offset_dims = gather_dim.getOffsetDims();
    auto start_index_map = gather_dim.getStartIndexMap();
    int64_t index_vector_dim = gather_dim.getIndexVectorDim();
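
    // Gather semantics, in brief: each batch position in start_indices selects
    // a slice of extents slice_sizes from the operand. start_index_map says
    // which operand dimensions the components of a start index apply to,
    // collapsed_slice_dims are the slice dimensions dropped from the output,
    // and offset_dims are the output dimensions that hold the surviving slice
    // dimensions. E.g. (illustrative values, not from the source): a 3x4
    // operand, 2x1 start_indices, index_vector_dim = 1, start_index_map = [0],
    // slice_sizes = [1, 4], collapsed_slice_dims = [0], offset_dims = [1]
    // gathers two rows of the operand into a 2x4 output.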

    // Slice_sizes.
    DenseIntElementsAttr slice_sizes_attr = op.slice_sizesAttr();
    SmallVector<int64_t, 4> slice_sizes;
    for (const APInt& dim : slice_sizes_attr.getValues<APInt>())
      slice_sizes.push_back(dim.getSExtValue());

    // Create a zero constant of the start_indices element type. We need an
    // integer constant here because the indices have an integer element type.
    Value zero_int_val = rewriter.create<mlir::arith::ConstantIntOp>(
        loc, 0, start_indices_type.getElementType());
    Type element_type = output_type.getElementType();
    Value zero_load_value = getZeroValue(element_type, loc, rewriter);
    // Initializing the output buffer with 0.
    fillBuffer(loc, output, zero_load_value, rewriter);
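    // The zero-initialization matters: the lowering below visits every
    // candidate operand position and accumulates select(pred, value, 0) into
    // the output, so the matching value survives and every other iteration
    // contributes zero.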

    // We fetch the shape of start_indices at index_vector_dim. In case
    // index_vector_dim is equal to the rank of start_indices, we implicitly
    // consider start_indices to have a trailing 1 dimension.
    unsigned start_indices_numbers =
        (index_vector_dim == start_indices_rank)
            ? 1
            : start_indices_shape[index_vector_dim];
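    // E.g. (illustrative): start_indices of shape 2x3 with index_vector_dim
    // = 1 gives start_indices_numbers = 3, while index_vector_dim = 2 (the
    // rank) means each scalar is its own start index, so the count is 1.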
    // We create the integer constants 0 .. start_indices_numbers - 1, which
    // are used below to index the index-vector dimension of start_indices.
    SmallVector<Value, 4> start_indices_index;
    for (unsigned i = 0; i < start_indices_numbers; i++) {
      Value i_val = rewriter.create<mlir::arith::ConstantIntOp>(
          loc, i, start_indices_type.getElementType());
      i_val = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                                  i_val);
      start_indices_index.push_back(i_val);
    }

    // S_in holds the per-dimension components of a starting index into the
    // operand tensor. O_in holds the corresponding per-dimension offsets from
    // that starting index. Both are initialized to 0.
    SmallVector<Value, 4> S_in;
    SmallVector<Value, 4> O_in;
    Value zero_index_val = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getIndexType(), zero_int_val);
    for (unsigned i = 0; i < operand_rank; i++) {
      S_in.push_back(zero_index_val);
      O_in.push_back(zero_index_val);
    }

    // batch_induction_vars stores the loop induction variables pertaining to
    // the batches of start indices.
    SmallVector<Value, 4> batch_induction_vars;
    // output_induction_vars stores the loop induction variables pertaining to
    // both batches and offsets within the output tensor.
    SmallVector<Value, 4> output_induction_vars;
    // Create loops to iterate over each batch of starting index.
    for (unsigned i = 0; i < start_indices_rank; i++) {
      // The i-th dimension of start_indices does not form a batch dimension
      // when i equals index_vector_dim.
      if (i == index_vector_dim) continue;
      AffineForOp for_op =
          rewriter.create<AffineForOp>(loc, 0, start_indices_shape[i]);
      batch_induction_vars.push_back(for_op.getInductionVar());
      output_induction_vars.push_back(for_op.getInductionVar());
      rewriter.setInsertionPointToStart(for_op.getBody());
    }
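
    // At this point the emitted IR is, schematically, one affine.for per batch
    // dimension. E.g. for 2x1 start_indices with index_vector_dim = 1 there is
    // a single batch loop:
    //   affine.for %i0 = 0 to 2 { ... }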

    // Create loops to iterate over each offset dimension within the output
    // tensor.
    for (unsigned i = 0, j = 0, k = 0, e = offset_dims.size(); i < e; i++) {
      AffineForOp for_op =
          rewriter.create<AffineForOp>(loc, 0, output_shape[offset_dims[i]]);
      rewriter.setInsertionPointToStart(for_op.getBody());
      // Skip past collapsed operand dimensions to find the first
      // non-collapsed one (j walks collapsed_slice_dims, k walks operand
      // dimensions).
      while (j < collapsed_slice_dims.size() && collapsed_slice_dims[j] == k) {
        j++;
        k++;
      }
      // Remap offset_dims[i] to that non-collapsed operand dimension.
      O_in[k++] = for_op.getInductionVar();
      // We assume offset_dims to be sorted, so the loop induction variable
      // gets inserted into output_induction_vars at the correct position.
      output_induction_vars.insert(
          output_induction_vars.begin() + offset_dims[i],
          for_op.getInductionVar());
    }
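
    // E.g. (illustrative): with collapsed_slice_dims = [0] and offset_dims =
    // [1], the single offset loop's induction variable lands in O_in[1], the
    // first non-collapsed operand dimension.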

    // Create loops to iterate over all dimensions within the operand tensor.
    SmallVector<Value, 4> operand_index;
    for (unsigned i = 0, k = 0; i < operand_rank; i++) {
      // We assume start_index_map has sorted dimensions. We only create loops
      // for those operand dimensions that are present in start_index_map; the
      // remaining dimensions take their index directly from O_in.
      if (k < start_index_map.size() && i == start_index_map[k]) {
        k++;
        AffineForOp for_op =
            rewriter.create<AffineForOp>(loc, 0, operand_shape[i]);
        operand_index.push_back(for_op.getInductionVar());
        rewriter.setInsertionPointToStart(for_op.getBody());
      } else {
        operand_index.push_back(O_in[i]);
      }
    }
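
    // These loops sweep the full extent of every gathered operand dimension;
    // the predicate built below then keeps only the position that matches
    // start index + offset, so each output element is computed correctly,
    // if redundantly.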

    // In case index_vector_dim is not equal to the rank of start_indices, we
    // walk the index vector, loading each component of the starting index and
    // updating batch_induction_vars accordingly.
    if (index_vector_dim != start_indices_rank) {
      for (unsigned i = 0; i < start_indices_numbers; i++) {
        batch_induction_vars.insert(
            batch_induction_vars.begin() + index_vector_dim,
            start_indices_index[i]);
        Value start_index = rewriter.create<AffineLoadOp>(loc, start_indices,
                                                          batch_induction_vars);
        start_index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), start_index);
        S_in[start_index_map[i]] = start_index;
        batch_induction_vars.erase(batch_induction_vars.begin() +
                                   index_vector_dim);
      }
    } else {
      // Since index_vector_dim equals the rank of start_indices, we can
      // directly fetch the start index using batch_induction_vars.
      Value start_index = rewriter.create<AffineLoadOp>(loc, start_indices,
                                                        batch_induction_vars);
      start_index = rewriter.create<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), start_index);
      S_in[0] = start_index;
    }
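
    // E.g. (illustrative): with 2x1 start_indices and index_vector_dim = 1,
    // batch iteration %i0 loads start_indices[%i0, 0] and places the
    // index-cast result into S_in[start_index_map[0]].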

    // We load the value at the current operand index; it is written to the
    // output tensor below only if the index constraints are met.
    Value load_value =
        rewriter.create<AffineLoadOp>(loc, operand, operand_index);
    SmallVector<Value, 4> predicates;
    // Adding offsets to the corresponding starting index and comparing it with
    // the corresponding operand index.
    for (unsigned k = 0, i = 0; k < start_index_map.size(); k++) {
      i = start_index_map[k];
      Value add_start_index_offset = rewriter.create<mlir::arith::AddIOp>(
          loc, rewriter.getIndexType(), S_in[i], O_in[i]);
      Value predicate = rewriter.create<mlir::arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, add_start_index_offset,
          operand_index[i]);
      predicates.push_back(predicate);
    }
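
    // Schematically, each iteration above emits:
    //   %sum  = arith.addi %s_in_i, %o_in_i : index
    //   %pred = arith.cmpi eq, %sum, %operand_index_i : index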

    // Since the number of predicates equals start_index_map.size(), we fold
    // them pairwise with arith::AndIOp, accumulating the conjunction of all
    // predicates in result_predicate.
    Value result_predicate = nullptr;
    for (unsigned i = 0; i < predicates.size() - 1; i += 2) {
      Value predicateA = predicates[i];
      Value predicateB = predicates[i + 1];
      Value and_predicate =
          rewriter.create<mlir::arith::AndIOp>(loc, predicateA, predicateB);
      result_predicate = (i == 0) ? and_predicate
                                  : rewriter.create<mlir::arith::AndIOp>(
                                        loc, result_predicate, and_predicate);
    }
    // We fetch the last predicate value. If it is the only predicate, we let
    // result_predicate be equal to it. Otherwise, with an odd number of
    // predicates, the last one is left unpaired by the loop above and must be
    // joined in with arith::AndIOp.
    Value predicate = predicates.back();
    if (!result_predicate) result_predicate = predicate;
    // Join the unpaired last predicate into result_predicate.
    else if (start_index_map.size() % 2 == 1)
      result_predicate = rewriter.create<mlir::arith::AndIOp>(
          loc, result_predicate, predicate);
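
    // E.g. with three predicates p0, p1, p2 the emitted reduction is
    // and(and(p0, p1), p2); with a single predicate it is used directly.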

    // We use the loaded value if the index computed by adding offsets to the
    // starting index equals the current operand index; otherwise we use 0.
    Value select_load = rewriter.create<mlir::SelectOp>(
        loc, result_predicate, load_value, zero_load_value);
    // We load value at output array.
    Value output_value =
        rewriter.create<AffineLoadOp>(loc, output, output_induction_vars);

    // The selected value is added to the value previously stored in the
    // output array.
    if (element_type.isa<FloatType>())
      output_value = rewriter.create<arith::AddFOp>(loc, element_type,
                                                    select_load, output_value);
    else
      output_value = rewriter.create<arith::AddIOp>(loc, element_type,
                                                    select_load, output_value);
    rewriter.create<AffineStoreOp>(loc, output_value, output,
                                   output_induction_vars);
    rewriter.eraseOp(op);
    return success();
  }
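
  // For the illustrative gather above (2x1 start indices into a 3x4 operand,
  // row dimension collapsed), the emitted structure is roughly:
  //
  //   affine.for %i0 = 0 to 2 {        // batch of start indices
  //     affine.for %i1 = 0 to 4 {      // offset within the slice
  //       affine.for %i2 = 0 to 3 {    // gathered operand dimension
  //         %s   = affine.load %start_indices[%i0, 0]
  //         %v   = affine.load %operand[%i2, %i1]
  //         %p   = arith.cmpi eq, <%s + 0>, %i2
  //         %sel = select %p, %v, %zero
  //         %out = affine.load %output[%i0, %i1]
  //         affine.store <%out + %sel>, %output[%i0, %i1]
  //       }
  //     }
  //   }
  //
  // This sketches the shape of the generated IR, not its verbatim form.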