in lib/Dialect/mhlo/IR/hlo_ops.cc [6206:6388]
// Parses convolution dimension numbers written as three bracketed dimension
// lists (input, kernel, output): `[...]x[...]->[...]`. Each list holds
// comma-separated entries that are either a non-negative integer (a spatial
// dimension), `?` (an unknown dimension that only occupies a position), or a
// single-character keyword naming one of the non-spatial dimensions allowed
// for that operand (matched via NonSpatialDimToString). On success `dnums` is
// set to the resulting ConvDimensionNumbersAttr; on failure a diagnostic has
// been emitted through `parser`.
ParseResult parseConvolutionDimensions(AsmParser& parser,
                                       ConvDimensionNumbersAttr& dnums) {
  // Parsing a single set of dim numbers gives the spatial dimensions as a
  // single vector (spatial dim -> position in the list) and the non-spatial
  // dimensions as a map from the NonSpatialDim enum to their position.
  using parse_dim_result_t =
      std::pair<llvm::SmallVector<int64_t>,
                llvm::SmallDenseMap<NonSpatialDim, int64_t, 4,
                                    DenseMapInfoNonSpatialDim>>;

  // Parses one bracketed dimension list into `parsed_dims` (which is cleared
  // first). `allowed_non_spatial_dims` is taken by value: each keyword that is
  // parsed is removed from it, so a repeated keyword is rejected and anything
  // left over at the end is reported as missing. It is an ordered std::set
  // (as opposed to an unordered set) because it is printed in error messages,
  // and ordering keeps those messages deterministic.
  auto parse_dims =
      [&](std::set<NonSpatialDim, std::greater<>> allowed_non_spatial_dims,
          parse_dim_result_t& parsed_dims) -> ParseResult {
    auto& spatial_dims = parsed_dims.first;
    auto& non_spatial_dims = parsed_dims.second;
    spatial_dims.clear();
    non_spatial_dims.clear();

    // Parse the starting [
    if (parser.parseLSquare()) {
      return failure();
    }

    // Maps each parsed spatial dimension to its position in the list.
    llvm::SmallDenseMap<int64_t, int64_t> spatial_dims_map;
    constexpr int64_t kInvalidDimension = -1;
    // Keep track of the maximum spatial dimension parsed as we expect to see
    // all the dimensions from 0 to maximum dimension parsed.
    int64_t max_parsed_spatial_dim = kInvalidDimension;
    int64_t index = 0;
    do {
      int64_t spatial_dim;
      auto dim_location = parser.getCurrentLocation();
      OptionalParseResult parseResult =
          parser.parseOptionalInteger(spatial_dim);
      if (parseResult.hasValue()) {
        if (parseResult.getValue().failed()) {
          return failure();
        }
        // We were successful in parsing an integer. Check if it is a valid
        // dimension (non-negative and no duplicate) and record its position.
        if (spatial_dim < 0)
          return parser.emitError(dim_location)
                 << "Unexpected dimension " << spatial_dim;
        if (!spatial_dims_map.try_emplace(spatial_dim, index).second)
          return parser.emitError(dim_location)
                 << "Duplicate entries for spatial dimension " << spatial_dim;
        max_parsed_spatial_dim = std::max(spatial_dim, max_parsed_spatial_dim);
      } else if (!parser.parseOptionalQuestion()) {
        // '?' means "unknown dimension" and is not represented in the return
        // value of this function; only `index` advances (bottom of the loop).
      } else {
        // We did not parse an integer or question mark. We expect a keyword
        // token naming one of the still-allowed non-spatial dimensions.
        StringRef keyword;
        if (parser.parseKeyword(&keyword)) {
          return failure();
        }
        if (keyword.size() != 1 || allowed_non_spatial_dims.empty()) {
          return parser.emitError(dim_location, "Unexpected keyword ")
                 << keyword;
        }
        auto allowed_it = std::find_if(
            allowed_non_spatial_dims.begin(), allowed_non_spatial_dims.end(),
            [&](NonSpatialDim allowed) {
              return keyword[0] == NonSpatialDimToString(allowed);
            });
        if (allowed_it == allowed_non_spatial_dims.end()) {
          mlir::InFlightDiagnostic diag =
              parser.emitError(dim_location, "Unexpected dimension ");
          diag << keyword << ", expecting ";
          llvm::interleaveComma(
              allowed_non_spatial_dims, diag,
              [&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
          return diag;
        }
        // Record the dimension, then remove it from the allowed set (via the
        // iterator, outside of any loop over the set) so that it cannot be
        // accepted a second time.
        non_spatial_dims.insert({*allowed_it, index});
        allowed_non_spatial_dims.erase(allowed_it);
      }
      index++;
    } while (parser.parseOptionalComma().succeeded());

    // Make sure all expected non-spatial dimensions are parsed.
    if (!allowed_non_spatial_dims.empty()) {
      mlir::InFlightDiagnostic diag =
          parser.emitError(parser.getCurrentLocation(), "Expected dimensions ");
      llvm::interleaveComma(
          allowed_non_spatial_dims, diag,
          [&](NonSpatialDim dim) { diag << NonSpatialDimToString(dim); });
      diag << " not specified";
      return diag;
    }

    // parse ending ]
    if (parser.parseRSquare()) {
      return failure();
    }

    // Spatial dimensions are unique and non-negative, so every dimension in
    // [0, max_parsed_spatial_dim] was specified iff the number of parsed
    // spatial dimensions equals max_parsed_spatial_dim + 1. Checking this
    // BEFORE sizing any container means malformed input such as
    // [b, 1000000000000, f] produces a diagnostic instead of a huge
    // allocation, and avoids computing max_parsed_spatial_dim + 1 (which
    // would overflow for INT64_MAX).
    const int64_t num_spatial_dimensions =
        static_cast<int64_t>(spatial_dims_map.size());
    if (max_parsed_spatial_dim != num_spatial_dimensions - 1) {
      // Collect (in ascending order) up to kPrintUnspecifiedDimsMax missing
      // dimensions for a more descriptive error. The scan is bounded: every
      // iteration either hits one of the num_spatial_dimensions specified
      // dimensions or records a missing one.
      constexpr size_t kPrintUnspecifiedDimsMax = 10;
      llvm::SmallVector<int64_t> unspecified_spatial_dims;
      for (int64_t dim = 0;
           dim <= max_parsed_spatial_dim &&
           unspecified_spatial_dims.size() < kPrintUnspecifiedDimsMax;
           ++dim) {
        if (!spatial_dims_map.count(dim))
          unspecified_spatial_dims.push_back(dim);
      }
      mlir::InFlightDiagnostic diag = parser.emitError(
          parser.getCurrentLocation(), "Expected spatial dimensions ");
      llvm::interleaveComma(unspecified_spatial_dims, diag);
      diag << " not specified";
      return diag;
    }

    // Store spatial dimensions in a vector which maps spatial dim (vector
    // index) -> position in the list. For example, for parsed dimension
    // numbers [0, 3, 2, <non-spatial>, <non-spatial>, 1] the spatial
    // dimension vector would be [0, 5, 2, 1].
    spatial_dims.resize(num_spatial_dimensions);
    for (const auto& entry : spatial_dims_map)
      spatial_dims[entry.first] = entry.second;
    return success();
  };

  // parse_dims clears `parsed_dims` on entry, so the spatial dimension
  // vectors can be moved out rather than copied. After a successful call
  // every requested non-spatial dimension is present, so operator[] below
  // never default-inserts.
  parse_dim_result_t parsed_dims;
  if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
    return failure();
  }
  llvm::SmallVector<int64_t> input_spatial_dimensions =
      std::move(parsed_dims.first);
  int64_t input_batch_dimension = parsed_dims.second[IOBatch];
  int64_t input_feature_dimension = parsed_dims.second[IOFeature];
  if (parser.parseKeyword("x")) return failure();
  if (parse_dims({KIFeature, KOFeature}, parsed_dims)) {
    return failure();
  }
  llvm::SmallVector<int64_t> kernel_spatial_dimensions =
      std::move(parsed_dims.first);
  int64_t kernel_input_feature_dimension = parsed_dims.second[KIFeature];
  int64_t kernel_output_feature_dimension = parsed_dims.second[KOFeature];
  if (parser.parseArrow()) {
    return failure();
  }
  if (parse_dims({IOBatch, IOFeature}, parsed_dims)) {
    return failure();
  }
  llvm::SmallVector<int64_t> output_spatial_dimensions =
      std::move(parsed_dims.first);
  int64_t output_batch_dimension = parsed_dims.second[IOBatch];
  int64_t output_feature_dimension = parsed_dims.second[IOFeature];
  dnums = ConvDimensionNumbersAttr::get(
      parser.getBuilder().getContext(), input_batch_dimension,
      input_feature_dimension, input_spatial_dimensions,
      kernel_input_feature_dimension, kernel_output_feature_dimension,
      kernel_spatial_dimensions, output_batch_dimension,
      output_feature_dimension, output_spatial_dimensions);
  return success();
}