void visit()

in lib/Analysis/shape_component_analysis.cc [65:229]


  // Computes what is needed for `requestedInfo` with two worklist-driven
  // passes: a backward pass that climbs from the request towards the defining
  // ops (so that only what was asked for gets analyzed, instead of analyzing
  // everything eagerly), and a forward pass that walks back down from the defs
  // to the uses, building symbolic expressions for shape and value components.
  //
  // NOTE(review): the `backward*`/`forward*` handlers live outside this block;
  // presumably they are what feeds the two worklists and `symbolicExprsMap` --
  // confirm against their definitions.
  void visit(ShapeOrValueInfo requestedInfo) {
    // Dispatchers. In each of them the order of the checks matters: dedicated
    // handlers are tried first and the "unknown" handler is the last resort.

    // Backward step for a shape request on `v`.
    auto climbShape = [this](Value v) {
      auto producer = v.getDefiningOp();
      if (v.getDefiningOp<shape::AssumingOp>()) {
        backwardAssumingShape(v);
      } else if (auto bcastOp =
                     v.getDefiningOp<mhlo::DynamicBroadcastInDimOp>()) {
        backwardDynamicBroadcastInDimShape(bcastOp);
      } else if (auto dynReshapeOp =
                     v.getDefiningOp<mhlo::DynamicReshapeOp>()) {
        backwardDynamicReshapeShape(dynReshapeOp);
      } else if (v.getDefiningOp<mhlo::ReduceOp>()) {
        backwardReduceShape(v);
      } else if (auto transposeOp = v.getDefiningOp<mhlo::TransposeOp>()) {
        backwardTransposeShape(transposeOp);
      } else if (auto selectOp = v.getDefiningOp<mhlo::SelectOp>()) {
        backwardSelectShape(selectOp);
      } else if (auto blockArg = v.dyn_cast<BlockArgument>()) {
        backwardBlockArgumentShape(blockArg);
      } else if (producer &&
                 producer->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
        backwardSameOperandsAndResultShape(v);
      } else {
        backwardUnknownShape(v);
      }
    };

    // Backward step for a value request on `v`.
    auto climbValue = [this](Value v) {
      if (auto shapeOfOp = v.getDefiningOp<shape::ShapeOfOp>()) {
        backwardShapeOf(shapeOfOp);
      } else if (auto numElementsOp =
                     v.getDefiningOp<shape::NumElementsOp>()) {
        backwardNumElements(numElementsOp);
      } else if (auto dimOp = v.getDefiningOp<tensor::DimOp>()) {
        backwardDim(dimOp);
      } else if (auto indexCastOp = v.getDefiningOp<arith::IndexCastOp>()) {
        backwardIndexCast(indexCastOp);
      } else if (auto fromElementsOp =
                     v.getDefiningOp<tensor::FromElementsOp>()) {
        backwardTensorFromElements(fromElementsOp);
      } else if (auto extractOp = v.getDefiningOp<tensor::ExtractOp>()) {
        backwardTensorExtract(extractOp);
      } else if (auto hloAdd = v.getDefiningOp<mhlo::AddOp>()) {
        backwardBinOp(hloAdd);
      } else if (auto hloMul = v.getDefiningOp<mhlo::MulOp>()) {
        backwardBinOp(hloMul);
      } else if (auto arithAdd = v.getDefiningOp<arith::AddIOp>()) {
        backwardBinOp(arithAdd);
      } else if (auto arithMul = v.getDefiningOp<arith::MulIOp>()) {
        backwardBinOp(arithMul);
      } else if (auto concatOp = v.getDefiningOp<mhlo::ConcatenateOp>()) {
        backwardConcatenate(concatOp);
      } else if (auto reshapeOp = v.getDefiningOp<mhlo::ReshapeOp>()) {
        backwardReshape(reshapeOp);
      } else if (auto sliceOp = v.getDefiningOp<mhlo::SliceOp>()) {
        backwardSlice(sliceOp);
      } else if (matchPattern(v, m_Constant())) {
        backwardConstant(v);
      } else {
        backwardUnknown(v);
      }
    };

    // Forward step for a shape request on `v`. Unlike the backward step there
    // is no dedicated BlockArgument case here; anything not matched below ends
    // up in `forwardUnknownShape`.
    auto descendShape = [this](Value v) {
      auto producer = v.getDefiningOp();
      if (v.getDefiningOp<shape::AssumingOp>()) {
        forwardAssumingShape(v);
      } else if (auto bcastOp =
                     v.getDefiningOp<mhlo::DynamicBroadcastInDimOp>()) {
        forwardDynamicBroadcastInDimShape(bcastOp);
      } else if (auto dynReshapeOp =
                     v.getDefiningOp<mhlo::DynamicReshapeOp>()) {
        forwardDynamicReshapeShape(dynReshapeOp);
      } else if (v.getDefiningOp<mhlo::ReduceOp>()) {
        forwardReduceShape(v);
      } else if (auto transposeOp = v.getDefiningOp<mhlo::TransposeOp>()) {
        forwardTransposeShape(transposeOp);
      } else if (auto selectOp = v.getDefiningOp<mhlo::SelectOp>()) {
        forwardSelectShape(selectOp);
      } else if (producer &&
                 producer->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
        forwardSameOperandsShape(v);
      } else {
        forwardUnknownShape(v);
      }
    };

    // Forward step for a value request on `v`. Additions and multiplications
    // are handed to `forwardBinOp` together with the matching AffineExpr
    // combiner.
    auto descendValue = [this](Value v) {
      if (auto shapeOfOp = v.getDefiningOp<shape::ShapeOfOp>()) {
        forwardShapeOf(shapeOfOp);
      } else if (auto numElementsOp =
                     v.getDefiningOp<shape::NumElementsOp>()) {
        forwardNumElements(numElementsOp);
      } else if (auto dimOp = v.getDefiningOp<tensor::DimOp>()) {
        forwardDim(dimOp);
      } else if (auto indexCastOp = v.getDefiningOp<arith::IndexCastOp>()) {
        forwardIndexCast(indexCastOp);
      } else if (auto fromElementsOp =
                     v.getDefiningOp<tensor::FromElementsOp>()) {
        forwardTensorFromElements(fromElementsOp);
      } else if (auto extractOp = v.getDefiningOp<tensor::ExtractOp>()) {
        forwardTensorExtract(extractOp);
      } else if (auto hloAdd = v.getDefiningOp<mhlo::AddOp>()) {
        forwardBinOp(hloAdd,
                     [](AffineExpr lhs, AffineExpr rhs) { return lhs + rhs; });
      } else if (auto hloMul = v.getDefiningOp<mhlo::MulOp>()) {
        forwardBinOp(hloMul,
                     [](AffineExpr lhs, AffineExpr rhs) { return lhs * rhs; });
      } else if (auto arithAdd = v.getDefiningOp<arith::AddIOp>()) {
        forwardBinOp(arithAdd,
                     [](AffineExpr lhs, AffineExpr rhs) { return lhs + rhs; });
      } else if (auto arithMul = v.getDefiningOp<arith::MulIOp>()) {
        forwardBinOp(arithMul,
                     [](AffineExpr lhs, AffineExpr rhs) { return lhs * rhs; });
      } else if (auto concatOp = v.getDefiningOp<mhlo::ConcatenateOp>()) {
        forwardConcatenate(concatOp);
      } else if (auto reshapeOp = v.getDefiningOp<mhlo::ReshapeOp>()) {
        forwardReshape(reshapeOp);
      } else if (auto sliceOp = v.getDefiningOp<mhlo::SliceOp>()) {
        forwardSlice(sliceOp);
      } else if (matchPattern(v, m_Constant())) {
        forwardConstant(v);
      } else {
        forwardUnknown(v);
      }
    };

    // Pass 1 (backward): starting from the request, climb up the operations to
    // collect everything that takes part in this shape or value computation.
    backwards_worklist.push_back(requestedInfo);
    while (!backwards_worklist.empty()) {
      ShapeOrValueInfo pending = backwards_worklist.pop_back_val();

      // Items that already have an entry in the map need no further work.
      if (symbolicExprsMap->count(pending)) continue;

      // Only int/index/float scalars and ranked tensors are of interest.
      Value v = pending.value();
      Type valueType = v.getType();
      if (!(valueType.isIntOrIndexOrFloat() ||
            valueType.isa<RankedTensorType>())) {
        continue;
      }

      if (pending.isShapeInfo()) {
        climbShape(v);
        continue;
      }

      // For value requests the filter is stricter: int/index scalars, or
      // statically shaped ranked tensors of rank 0 or 1.
      auto rankedType = valueType.dyn_cast<RankedTensorType>();
      const bool worthTracking =
          valueType.isIntOrIndex() ||
          (rankedType && rankedType.getRank() <= 1 &&
           rankedType.hasStaticShape());
      if (!worthTracking) continue;

      assert(pending.isValueInfo() && "Expect value info at this point.");
      climbValue(v);
    }

    // Pass 2 (forward): walk down from the defs to the uses, building symbolic
    // expressions for shape and value components.
    while (!forwards_worklist.empty()) {
      auto pending = forwards_worklist.pop_back_val();

      // Items that already have an entry in the map need no further work.
      if (symbolicExprsMap->count(pending)) continue;

      Value v = pending.value();
      if (!pending.isValueInfo()) {
        descendShape(v);
        continue;
      }

      assert(pending.isValueInfo() && "Expect value info at this point.");
      descendValue(v);
    }
  }