void PropagateBroadcast()

in lib/Dialect/mhlo/transforms/broadcast_propagation.cc [165:294]


void PropagateBroadcast(DynamicBroadcastInDimOp root) {
  OpBuilder builder(root.getContext());
  BroadcastIntent root_bcast_intent = {
      root.getResult().getType().cast<RankedTensorType>(), root.operand(),
      root.output_dimensions(), root.broadcast_dimensions()};

  // We can move broadcasts up over (broadcasting) element-wise operations and
  // propagate them through the IR to perform them early. Instead of
  // broadcasting the result of such an op, we can broadcast the operands and
  // apply the element-wise operation to them.
  //
  // To avoid exponential growth of the IR, we will do this in two phases:
  //   1) First, we collect all the unique broadcast intents. These are
  //      broadcasted versions of values that we are interested in. They may
  //      later be materialized as an explicit broadcast or they can be the
  //      direct result of an operation over which a broadcast was propagated.
  //   2) Then, we fulfill every broadcast intent in reverse topological order
  //      to ensure that their dependencies (the broadcasted operands) are
  //      available.

  // Collect all the broadcast intents, starting with the root. Record
  // dependencies for later lookups.
  DenseSet<BroadcastIntent> bcast_intents = {root_bcast_intent};
  SmallVector<BroadcastIntent> bcast_intents_ordered = {root_bcast_intent};
  DenseMap<BroadcastIntent, SmallVector<BroadcastIntent>>
      bcast_propagation_dependencies;
  Block *the_block = root->getBlock();

  // We use the ordered broadcast intents as a worklist, the first `i` intents
  // of which have been processed.
  auto empty_broadcast_dimensions = builder.getI64TensorAttr({});
  for (int i = 0; i < bcast_intents_ordered.size(); ++i) {
    BroadcastIntent it = bcast_intents_ordered[i];
    Operation *producer_op = it.target_value.getDefiningOp();

    // We can propagate broadcasts over (broadcasting) element-wise operations
    // with the restriction that they must be in the same block as they may
    // depend on assumptions.
    if (producer_op && producer_op->getBlock() == the_block &&
        AllowsForBroadcastPropagation(producer_op)) {
      // Collect this broadcast propagation's dependencies: the broadcasted
      // versions of the operands that we will need in the second phase.
      SmallVector<BroadcastIntent> dependencies;
      for (auto operand : producer_op->getOperands()) {
        auto operand_ty = operand.getType().cast<RankedTensorType>();
        auto operand_broadcast_dimensions = operand_ty.getRank() == 0
                                                ? empty_broadcast_dimensions
                                                : it.broadcast_dimensions;
        auto bcasted_operand_ty = RankedTensorType::get(
            it.result_type.getShape(), operand_ty.getElementType());
        BroadcastIntent bcasted_operand_intent = {bcasted_operand_ty, operand,
                                                  it.output_dimensions,
                                                  operand_broadcast_dimensions};
        dependencies.push_back(bcasted_operand_intent);

        // If this broadcast intent was not yet seen, add it to the worklist.
        // Otherwise, we know it will be fulfilled earlier.
        if (!bcast_intents.count(bcasted_operand_intent)) {
          bcast_intents_ordered.push_back(bcasted_operand_intent);
          bcast_intents.insert(bcasted_operand_intent);
        }
      }
      bcast_propagation_dependencies[it] = dependencies;
    }
  }

  // Realize all the broadcast intents in reverse topological order of the
  // producer ops. We can use the positions in the block for this. All broadcast
  // intents outside the block (e.g. arguments) will be sorted towards the
  // front.
  // This ordering is independent of the output dimensions as dependencies can
  // only occur between broadcast intents of the same output dimension.
  DenseMap<BroadcastIntent, Value> realizations;
  std::sort(bcast_intents_ordered.begin(), bcast_intents_ordered.end(),
            [&](const BroadcastIntent &a, const BroadcastIntent &b) {
              Operation *producer_op_a = a.target_value.getDefiningOp();
              Operation *producer_op_b = b.target_value.getDefiningOp();
              bool a_in_block = producer_op_a != nullptr &&
                                producer_op_a->getBlock() == the_block;
              bool b_in_block = producer_op_b != nullptr &&
                                producer_op_b->getBlock() == the_block;
              if (a_in_block && b_in_block) {
                return producer_op_a->isBeforeInBlock(producer_op_b);
              }
              return !a_in_block && b_in_block;
            });
  for (auto it : bcast_intents_ordered) {
    Operation *producer_op = it.target_value.getDefiningOp();

    // Realize broadcast intent for an element-wise operation based on the
    // broadcasted operands, if possible.
    if (bcast_propagation_dependencies.count(it)) {
      assert(producer_op && producer_op->getBlock() == the_block &&
             AllowsForBroadcastPropagation(producer_op) &&
             "expect (broadcasting)  element-wise op in the same block");
      auto bcasted_operands =
          llvm::to_vector(llvm::map_range(bcast_propagation_dependencies[it],
                                          [&](BroadcastIntent operand_intent) {
                                            return realizations[operand_intent];
                                          }));
      SetInsertionPointToEarliestPointWithAllValuesAvailable(builder, the_block,
                                                             bcasted_operands);
      OperationState new_producer_op_state(
          producer_op->getLoc(), producer_op->getName().getStringRef(),
          bcasted_operands, it.result_type, producer_op->getAttrs());
      Operation *new_producer_op =
          builder.createOperation(new_producer_op_state);
      assert(new_producer_op->getNumResults() == 1 &&
             "expect exactly one result");
      realizations[it] = new_producer_op->getResults().front();
      continue;
    }

    // Fall back to explicit broadcasts, otherwise.
    SetInsertionPointToEarliestPointWithAllValuesAvailable(
        builder, the_block, ValueRange{it.target_value, it.output_dimensions});
    realizations[it] = builder.create<DynamicBroadcastInDimOp>(
        it.target_value.getLoc(), it.result_type, it.target_value,
        it.output_dimensions,
        it.broadcast_dimensions.cast<DenseIntElementsAttr>());
  }

  // Lookup the replacement for the root operation.
  auto replacement = realizations[root_bcast_intent];
  root->replaceAllUsesWith(ValueRange{replacement});

  // Erase all the operations that have become redundant as a result of this
  // rewrite.
  TransitivelyEraseUnusedSideEffectFreeOps(root);
}