in lib/Analysis/shape_component_analysis.cc [65:229]
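  // Visits the requested shape or value component together with everything it
  // transitively depends on, recording the resulting symbolic expressions in
  // symbolicExprsMap.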
void visit(ShapeOrValueInfo requestedInfo) {
backwards_worklist.push_back(requestedInfo);
    // First, we climb up the use-def chain to collect the set of all ops
    // taking part in this shape or value computation. An alternative would be
    // analyzing everything eagerly; this backwards pass allows us to be lazy.
while (!backwards_worklist.empty()) {
      ShapeOrValueInfo transitivelyRequestedInfo =
          backwards_worklist.pop_back_val();
      // Skip if already processed.
      if (symbolicExprsMap->count(transitivelyRequestedInfo)) continue;
// Skip irrelevant cases early.
Value value = transitivelyRequestedInfo.value();
Type ty = value.getType();
if (!ty.isIntOrIndexOrFloat() && !ty.isa<RankedTensorType>()) continue;
// Handle shapes.
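      // Dispatch on the op that defines the requested shape; each backward
      // handler is expected to push the components it depends on onto the
      // worklists.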
if (transitivelyRequestedInfo.isShapeInfo()) {
if (value.getDefiningOp<shape::AssumingOp>()) {
backwardAssumingShape(value);
} else if (auto bcast =
value.getDefiningOp<mhlo::DynamicBroadcastInDimOp>()) {
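          // The result shape of mhlo.dynamic_broadcast_in_dim is given by its
          // output_dimensions operand.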
backwardDynamicBroadcastInDimShape(bcast);
} else if (auto reshape =
value.getDefiningOp<mhlo::DynamicReshapeOp>()) {
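          // The result shape of mhlo.dynamic_reshape is given by its
          // output_shape operand.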
backwardDynamicReshapeShape(reshape);
} else if (value.getDefiningOp<mhlo::ReduceOp>()) {
backwardReduceShape(value);
} else if (auto transpose = value.getDefiningOp<mhlo::TransposeOp>()) {
backwardTransposeShape(transpose);
} else if (auto select = value.getDefiningOp<mhlo::SelectOp>()) {
backwardSelectShape(select);
} else if (auto arg = value.dyn_cast<BlockArgument>()) {
backwardBlockArgumentShape(arg);
} else if (value.getDefiningOp() &&
value.getDefiningOp()
->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
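          // Ops with this trait produce a result with the same shape as their
          // operands, so an operand's shape can stand in for the result's.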
backwardSameOperandsAndResultShape(value);
} else {
backwardUnknownShape(value);
}
continue;
}
      // Skip irrelevant cases early: only integer/index scalars and
      // statically shaped tensors of rank at most 1 (e.g. shape tensors) can
      // carry value information that we track.
auto ranked_ty = ty.dyn_cast<RankedTensorType>();
bool is_possibly_interesting_scalar = ty.isIntOrIndex();
bool is_possibly_interesting_tensor =
ranked_ty && ranked_ty.getRank() <= 1 && ranked_ty.hasStaticShape();
if (!is_possibly_interesting_scalar && !is_possibly_interesting_tensor) {
continue;
}
// Handle values.
assert(transitivelyRequestedInfo.isValueInfo() &&
"Expect value info at this point.");
if (auto shapeof = value.getDefiningOp<shape::ShapeOfOp>()) {
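        // shape.shape_of materializes the dimension sizes of its operand as a
        // tensor of values.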
backwardShapeOf(shapeof);
} else if (auto num_elements =
value.getDefiningOp<shape::NumElementsOp>()) {
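        // shape.num_elements is the product of all extents of its operand
        // shape.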
backwardNumElements(num_elements);
} else if (auto dim = value.getDefiningOp<tensor::DimOp>()) {
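        // tensor.dim extracts a single dimension size of its source tensor.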
backwardDim(dim);
} else if (auto cast = value.getDefiningOp<arith::IndexCastOp>()) {
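        // arith.index_cast converts between index and integer types without
        // changing the underlying value.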
backwardIndexCast(cast);
} else if (auto fromElements =
value.getDefiningOp<tensor::FromElementsOp>()) {
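        // tensor.from_elements packs its scalar operands into a statically
        // shaped tensor.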
backwardTensorFromElements(fromElements);
} else if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) {
backwardTensorExtract(extract);
} else if (auto add = value.getDefiningOp<mhlo::AddOp>()) {
backwardBinOp(add);
} else if (auto mul = value.getDefiningOp<mhlo::MulOp>()) {
backwardBinOp(mul);
} else if (auto add = value.getDefiningOp<arith::AddIOp>()) {
backwardBinOp(add);
} else if (auto mul = value.getDefiningOp<arith::MulIOp>()) {
backwardBinOp(mul);
} else if (auto concat = value.getDefiningOp<mhlo::ConcatenateOp>()) {
backwardConcatenate(concat);
} else if (auto reshape = value.getDefiningOp<mhlo::ReshapeOp>()) {
backwardReshape(reshape);
} else if (auto slice = value.getDefiningOp<mhlo::SliceOp>()) {
backwardSlice(slice);
} else if (matchPattern(value, m_Constant())) {
backwardConstant(value);
} else {
backwardUnknown(value);
}
}
// Second, we walk down from the defs to the uses, building symbolic
// expressions for shape and value components.
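    // By now the backward pass has enqueued every transitively required
    // component, so each handler can build its result from the symbolic
    // expressions of its dependencies.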
while (!forwards_worklist.empty()) {
auto transitivelyRequestedInfo = forwards_worklist.pop_back_val();
// Skip if already processed.
if (symbolicExprsMap->count(transitivelyRequestedInfo)) continue;
      Value value = transitivelyRequestedInfo.value();
      // Handle shapes.
      if (!transitivelyRequestedInfo.isValueInfo()) {
if (value.getDefiningOp<shape::AssumingOp>()) {
forwardAssumingShape(value);
} else if (auto broadcast =
value.getDefiningOp<mhlo::DynamicBroadcastInDimOp>()) {
forwardDynamicBroadcastInDimShape(broadcast);
} else if (auto reshape =
value.getDefiningOp<mhlo::DynamicReshapeOp>()) {
forwardDynamicReshapeShape(reshape);
} else if (value.getDefiningOp<mhlo::ReduceOp>()) {
forwardReduceShape(value);
} else if (auto transpose = value.getDefiningOp<mhlo::TransposeOp>()) {
forwardTransposeShape(transpose);
} else if (auto select = value.getDefiningOp<mhlo::SelectOp>()) {
forwardSelectShape(select);
} else if (value.getDefiningOp() &&
value.getDefiningOp()
->hasTrait<OpTrait::SameOperandsAndResultShape>()) {
forwardSameOperandsShape(value);
} else {
forwardUnknownShape(value);
}
continue;
}
// Handle values.
assert(transitivelyRequestedInfo.isValueInfo() &&
"Expect value info at this point.");
if (auto shapeof = value.getDefiningOp<shape::ShapeOfOp>()) {
forwardShapeOf(shapeof);
} else if (auto num_elements =
value.getDefiningOp<shape::NumElementsOp>()) {
forwardNumElements(num_elements);
} else if (auto dim = value.getDefiningOp<tensor::DimOp>()) {
forwardDim(dim);
} else if (auto cast = value.getDefiningOp<arith::IndexCastOp>()) {
forwardIndexCast(cast);
} else if (auto fromElements =
value.getDefiningOp<tensor::FromElementsOp>()) {
forwardTensorFromElements(fromElements);
} else if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) {
forwardTensorExtract(extract);
} else if (auto add = value.getDefiningOp<mhlo::AddOp>()) {
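        // Additions and multiplications (mhlo and arith alike) are folded
        // into affine expressions over the operands' symbols.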
forwardBinOp(add, [](AffineExpr a, AffineExpr b) { return a + b; });
} else if (auto mul = value.getDefiningOp<mhlo::MulOp>()) {
forwardBinOp(mul, [](AffineExpr a, AffineExpr b) { return a * b; });
} else if (auto add = value.getDefiningOp<arith::AddIOp>()) {
forwardBinOp(add, [](AffineExpr a, AffineExpr b) { return a + b; });
} else if (auto mul = value.getDefiningOp<arith::MulIOp>()) {
forwardBinOp(mul, [](AffineExpr a, AffineExpr b) { return a * b; });
} else if (auto concat = value.getDefiningOp<mhlo::ConcatenateOp>()) {
forwardConcatenate(concat);
} else if (auto reshape = value.getDefiningOp<mhlo::ReshapeOp>()) {
forwardReshape(reshape);
} else if (auto slice = value.getDefiningOp<mhlo::SliceOp>()) {
forwardSlice(slice);
} else if (matchPattern(value, m_Constant())) {
forwardConstant(value);
} else {
forwardUnknown(value);
}
}
}