in lib/Dialect/mhlo/IR/hlo_ops.cc [3126:3310]
// Parses the custom assembly form of reduce-op into `result`.
//
// The operand list comes first, as zero or more parenthesized pairs:
//   (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
// It is followed by one of two variants:
//
//  * Region-based variant (no `applies` keyword):
//      across dimensions = [...] <optional attr-dict> : <func-type>
//      reducer(%a: <type> <optional loc>, %b: <type> <optional loc>) ...
//      <region>
//    Each parenthesized pair after `reducer` names one block argument for an
//    input followed by one block argument for its init value.
//
//  * Compact variant:
//      applies <inner-op> across dimensions = [...] : <func-type>
//          <optional loc>
//    The reducer body is synthesized here as a single block holding
//    `<inner-op>` applied to two block arguments, followed by a ReturnOp.
//
// Returns failure() as soon as any parsing step fails, success() otherwise.
ParseResult parseReduceOp(OpAsmParser& parser, OperationState& result) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  Location currLocation = parser.getEncodedSourceLoc(loc);

  // Parse the operands of reduce-op, this is a list of pair under the form:
  //   (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
  // Each input to reduce is paired with its init value, even though in memory
  // they are stored with the input first and the init values after.
  SmallVector<OpAsmParser::OperandType, 2> operands;
  SmallVector<OpAsmParser::OperandType, 2> initOperands;
  do {
    // The separating comma is optional, so its absence is deliberately
    // ignored.
    (void)parser.parseOptionalComma();
    // No further '(' means the operand list is finished.
    if (parser.parseOptionalLParen()) break;
    OpAsmParser::OperandType operand, initOperand;
    if (parser.parseOperand(operand) || parser.parseKeyword("init") ||
        parser.parseColon() || parser.parseOperand(initOperand) ||
        parser.parseRParen())
      return failure();
    operands.push_back(operand);
    initOperands.push_back(initOperand);
  } while (true);
  // All inputs first, then all init values.
  operands.append(initOperands);

  // Check if we are parsing the compact version of reduce-op:
  //   mhlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
  // else parse the "region-based" variant.
  if (failed(parser.parseOptionalKeyword("applies"))) {
    // Parse the dimensions, the optional attribute dictionary, reduce-op's
    // function-type and the `reducer` keyword that introduces the region.
    SmallVector<int64_t> dimensions;
    auto parseDim = [&]() -> ParseResult {
      if (parser.parseInteger(dimensions.emplace_back())) return failure();
      return success();
    };

    FunctionType reduceOpFntype;
    if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
        parser.parseEqual() ||
        parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
                                       parseDim) ||
        parser.parseOptionalAttrDict(result.attributes) ||
        parser.parseColon() || parser.parseType(reduceOpFntype) ||
        parser.parseKeyword("reducer"))
      return failure();
    OpBuilder builder(parser.getBuilder().getContext());
    result.addAttribute("dimensions", GetI64ElementsAttr(dimensions, &builder));

    // Parse the "reducer" region now. Block arguments for inputs and for init
    // values are collected separately so that they can be laid out in the
    // same order as the operands: all inputs first, then all init values.
    SmallVector<OpAsmParser::OperandType, 2> reducerOperands;
    SmallVector<OpAsmParser::OperandType, 2> reducerInitOperands;
    SmallVector<Type, 2> reducerTypes;
    SmallVector<Type, 2> reducerInitTypes;
    SmallVector<Optional<Location>, 2> reducerLocs;
    SmallVector<Optional<Location>, 2> reducerInitLocs;
    // Parses one `%name : type <optional loc>` block argument and appends its
    // pieces to the given output vectors.
    auto parseBlockOperand =
        [&](SmallVectorImpl<OpAsmParser::OperandType>& operands,
            SmallVectorImpl<Type>& types,
            SmallVectorImpl<Optional<Location>>& locs) -> ParseResult {
      OpAsmParser::OperandType operand;
      Type type;
      Optional<Location> loc;
      if (parser.parseRegionArgument(operand) || parser.parseColon() ||
          parser.parseType(type) || parser.parseOptionalLocationSpecifier(loc))
        return failure();
      operands.push_back(operand);
      types.push_back(type);
      locs.push_back(loc);
      return success();
    };
    do {
      // No further '(' means the block-argument list is finished.
      if (failed(parser.parseOptionalLParen())) break;
      if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) ||
          parser.parseComma() ||
          parseBlockOperand(reducerInitOperands, reducerInitTypes,
                            reducerInitLocs) ||
          parser.parseRParen())
        return failure();
    } while (true);
    reducerOperands.append(reducerInitOperands);
    reducerTypes.append(reducerInitTypes);
    reducerLocs.append(reducerInitLocs);
    result.addTypes(reduceOpFntype.getResults());

    // Derive the SSA-values for reduce-op's operands and parse the region.
    if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
                               result.operands) ||
        parser.parseRegion(*result.addRegion(), reducerOperands, reducerTypes))
      return failure();

    // Attach the explicitly written locations to the entry-block arguments.
    // The region can have no block at all (an empty body with no reducer
    // arguments); calling front() on it would be invalid, so skip this step
    // and leave the malformed op to be rejected by op verification.
    Region& reducerRegion = *result.regions.front();
    if (!reducerRegion.empty()) {
      for (auto argAndLoc :
           llvm::zip(reducerRegion.front().getArguments(), reducerLocs))
        if (std::get<1>(argAndLoc))
          std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).getValue());
    }

    // No trailing location is parsed in this variant, so the op takes the
    // location at which parsing of it started.
    result.location = currLocation;
    return success();
  }

  // Parse the inner-op name and check if the contract on inner-op
  // mentioned in "isEligibleForCompactPrint::E2" for pretty-printing is met.
  FailureOr<OperationName> innerOpNameInfo = parser.parseCustomOperationName();
  if (failed(innerOpNameInfo)) return failure();

  StringRef innerOpName = innerOpNameInfo->getStringRef();
  Dialect* innerOpDialect = innerOpNameInfo->getDialect();
  if (!innerOpDialect || !innerOpDialect->getNamespace().equals("mhlo") ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::NOperands<2>::Impl>() ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::OneResult>() ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::SameOperandsAndResultType>() ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::IsCommutative>() ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::ZeroRegion>()) {
    parser.emitError(loc,
                     "expected the inner-op to be a commutative binary-op from "
                     "mhlo dialect, zero region, producing single result such "
                     "that the operands and result all have the same type");
    return failure();
  }

  // Parse the inner-op dimensions, reduce-op's function-type and
  // optional location.
  SmallVector<int64_t> dimensions;
  auto parseDim = [&]() -> ParseResult {
    if (parser.parseInteger(dimensions.emplace_back())) return failure();
    return success();
  };

  Optional<Location> explicitLoc;
  FunctionType reduceOpFntype;
  if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
      parser.parseEqual() ||
      parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) ||
      parser.parseColon() || parser.parseType(reduceOpFntype) ||
      parser.parseOptionalLocationSpecifier(explicitLoc))
    return failure();

  // The inner-op's type is derived from the first input type below, so a
  // function type with at least one input is required.
  if (!reduceOpFntype) return parser.emitError(loc, "expected function type");
  if (reduceOpFntype.getInputs().empty())
    return parser.emitError(loc,
                            "input types missing in reduce-op function type");

  // If location of reduce-op is explicitly provided, then use it; Else use
  // the parser's current location.
  Location reduceOpLoc = explicitLoc.getValueOr(currLocation);

  // Derive the SSA-values for reduce-op's operands.
  if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
                             result.operands))
    return failure();

  // Derive the type of inner-op from that of reduce-op's input operand: a
  // rank-0 tensor of the first input's element type.
  auto innerOpType = RankedTensorType::get(
      /*shape=*/{}, getElementTypeOrSelf(reduceOpFntype.getInput(0)));

  // Add a region for reduce-op.
  Region& region = *result.addRegion();

  // Create a basic-block inside reduce-op's region, taking the two values
  // that the inner-op combines.
  Block& block = region.emplaceBlock();
  auto lhs = block.addArgument(innerOpType, reduceOpLoc);
  auto rhs = block.addArgument(innerOpType, reduceOpLoc);

  // Create and insert an "inner-op" operation in the block.
  OpBuilder builder(parser.getBuilder().getContext());
  builder.setInsertionPointToStart(&block);

  OperationState innerOpState(reduceOpLoc, innerOpName);
  innerOpState.operands.push_back(lhs);
  innerOpState.operands.push_back(rhs);
  innerOpState.addTypes(innerOpType);

  Operation* innerOp = builder.createOperation(innerOpState);

  // Insert a return statement in the block returning the inner-op's result.
  builder.create<ReturnOp>(innerOp->getLoc(), innerOp->getResults());

  // Populate the reduce-op operation-state with result-type, location, and
  // dimension attribute.
  result.addTypes(reduceOpFntype.getResults());
  result.location = innerOp->getLoc();
  result.addAttribute("dimensions", GetI64ElementsAttr(dimensions, &builder));

  return success();
}