in tensorflow/tensorflow/compiler/xla/service/hlo_parser.cc [2734:2983]
bool HloParser::ParseAttributeHelper(
const std::unordered_map<string, AttrConfig>& attrs,
std::unordered_set<string>* seen_attrs) {
LocTy loc = lexer_.GetLoc();
string name;
if (!ParseAttributeName(&name)) {
return Error(loc, "error parsing attributes");
}
VLOG(3) << "Parsing attribute " << name;
if (!seen_attrs->insert(name).second) {
return Error(loc, StrFormat("attribute %s already exists", name));
}
auto attr_it = attrs.find(name);
if (attr_it == attrs.end()) {
string allowed_attrs;
if (attrs.empty()) {
allowed_attrs = "No attributes are allowed here.";
} else {
allowed_attrs = StrCat(
"Allowed attributes: ",
StrJoin(attrs, ", ",
[&](string* out, const std::pair<string, AttrConfig>& kv) {
StrAppend(out, kv.first);
}));
}
return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
allowed_attrs));
}
AttrTy attr_type = attr_it->second.attr_type;
void* attr_out_ptr = attr_it->second.result;
bool success = [&] {
LocTy attr_loc = lexer_.GetLoc();
switch (attr_type) {
case AttrTy::kBool: {
bool result;
if (!ParseBool(&result)) {
return false;
}
static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kInt64: {
int64 result;
if (!ParseInt64(&result)) {
return false;
}
static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kInt32: {
int64 result;
if (!ParseInt64(&result)) {
return false;
}
if (result != static_cast<int32>(result)) {
return Error(attr_loc, "value out of range for int32");
}
static_cast<optional<int32>*>(attr_out_ptr)
->emplace(static_cast<int32>(result));
return true;
}
case AttrTy::kFloat: {
double result;
if (!ParseDouble(&result)) {
return false;
}
if (result > std::numeric_limits<float>::max() ||
result < std::numeric_limits<float>::lowest()) {
return Error(attr_loc, "value out of range for float");
}
static_cast<optional<float>*>(attr_out_ptr)
->emplace(static_cast<float>(result));
return true;
}
case AttrTy::kHloComputation: {
HloComputation* result = nullptr;
if (!ParseHloComputation(&result)) {
return false;
}
static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kBracedHloComputationList: {
std::vector<HloComputation*> result;
if (!ParseHloComputationList(&result)) {
return false;
}
static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kFftType: {
FftType result;
if (!ParseFftType(&result)) {
return false;
}
static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kComparisonDirection: {
ComparisonDirection result;
if (!ParseComparisonDirection(&result)) {
return false;
}
static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kWindow: {
Window result;
if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
return false;
}
static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kConvolutionDimensionNumbers: {
ConvolutionDimensionNumbers result;
if (!ParseConvolutionDimensionNumbers(&result)) {
return false;
}
static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kSharding: {
OpSharding sharding;
if (!ParseSharding(&sharding)) {
return false;
}
static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
return true;
}
case AttrTy::kParameterReplication: {
ParameterReplication parameter_replication;
if (!ParseParameterReplication(¶meter_replication)) {
return false;
}
static_cast<optional<ParameterReplication>*>(attr_out_ptr)
->emplace(parameter_replication);
return true;
}
case AttrTy::kInstructionList: {
std::vector<HloInstruction*> result;
if (!ParseInstructionNames(&result)) {
return false;
}
static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kFusionKind: {
HloInstruction::FusionKind result;
if (!ParseFusionKind(&result)) {
return false;
}
static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kBracedInt64List: {
std::vector<int64> result;
if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
&result)) {
return false;
}
static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kBracedInt64ListList: {
std::vector<std::vector<int64>> result;
if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace,
TokKind::kComma, &result)) {
return false;
}
static_cast<optional<std::vector<std::vector<int64>>>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kSliceRanges: {
SliceRanges result;
if (!ParseSliceRanges(&result)) {
return false;
}
static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kPaddingConfig: {
PaddingConfig result;
if (!ParsePaddingConfig(&result)) {
return false;
}
static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kString: {
string result;
if (!ParseString(&result)) {
return false;
}
static_cast<optional<string>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kMetadata: {
OpMetadata result;
if (!ParseMetadata(&result)) {
return false;
}
static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kDistribution: {
RandomDistribution result;
if (!ParseRandomDistribution(&result)) {
return false;
}
static_cast<optional<RandomDistribution>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kDomain: {
return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
}
case AttrTy::kPrecisionList: {
std::vector<PrecisionConfig::Precision> result;
if (!ParsePrecisionList(&result)) {
return false;
}
static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kShapeList: {
std::vector<Shape> result;
if (!ParseShapeList(&result)) {
return false;
}
static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
->emplace(result);
return true;
}
}
}();
if (!success) {
return Error(loc, StrFormat("error parsing attribute %s", name));
}
return true;
}