bool HloParser::ParseAttributeHelper()

in tensorflow/tensorflow/compiler/xla/service/hlo_parser.cc [2734:2983]


bool HloParser::ParseAttributeHelper(
    const std::unordered_map<string, AttrConfig>& attrs,
    std::unordered_set<string>* seen_attrs) {
  LocTy loc = lexer_.GetLoc();
  string name;
  if (!ParseAttributeName(&name)) {
    return Error(loc, "error parsing attributes");
  }
  VLOG(3) << "Parsing attribute " << name;
  if (!seen_attrs->insert(name).second) {
    return Error(loc, StrFormat("attribute %s already exists", name));
  }
  auto attr_it = attrs.find(name);
  if (attr_it == attrs.end()) {
    string allowed_attrs;
    if (attrs.empty()) {
      allowed_attrs = "No attributes are allowed here.";
    } else {
      allowed_attrs = StrCat(
          "Allowed attributes: ",
          StrJoin(attrs, ", ",
                  [&](string* out, const std::pair<string, AttrConfig>& kv) {
                    StrAppend(out, kv.first);
                  }));
    }
    return Error(loc, StrFormat("unexpected attribute \"%s\".  %s", name,
                                allowed_attrs));
  }
  AttrTy attr_type = attr_it->second.attr_type;
  void* attr_out_ptr = attr_it->second.result;
  bool success = [&] {
    LocTy attr_loc = lexer_.GetLoc();
    switch (attr_type) {
      case AttrTy::kBool: {
        bool result;
        if (!ParseBool(&result)) {
          return false;
        }
        static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kInt64: {
        int64 result;
        if (!ParseInt64(&result)) {
          return false;
        }
        static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kInt32: {
        int64 result;
        if (!ParseInt64(&result)) {
          return false;
        }
        if (result != static_cast<int32>(result)) {
          return Error(attr_loc, "value out of range for int32");
        }
        static_cast<optional<int32>*>(attr_out_ptr)
            ->emplace(static_cast<int32>(result));
        return true;
      }
      case AttrTy::kFloat: {
        double result;
        if (!ParseDouble(&result)) {
          return false;
        }
        if (result > std::numeric_limits<float>::max() ||
            result < std::numeric_limits<float>::lowest()) {
          return Error(attr_loc, "value out of range for float");
        }
        static_cast<optional<float>*>(attr_out_ptr)
            ->emplace(static_cast<float>(result));
        return true;
      }
      case AttrTy::kHloComputation: {
        HloComputation* result = nullptr;
        if (!ParseHloComputation(&result)) {
          return false;
        }
        static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kBracedHloComputationList: {
        std::vector<HloComputation*> result;
        if (!ParseHloComputationList(&result)) {
          return false;
        }
        static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kFftType: {
        FftType result;
        if (!ParseFftType(&result)) {
          return false;
        }
        static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kComparisonDirection: {
        ComparisonDirection result;
        if (!ParseComparisonDirection(&result)) {
          return false;
        }
        static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kWindow: {
        Window result;
        if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
          return false;
        }
        static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kConvolutionDimensionNumbers: {
        ConvolutionDimensionNumbers result;
        if (!ParseConvolutionDimensionNumbers(&result)) {
          return false;
        }
        static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kSharding: {
        OpSharding sharding;
        if (!ParseSharding(&sharding)) {
          return false;
        }
        static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
        return true;
      }
      case AttrTy::kParameterReplication: {
        ParameterReplication parameter_replication;
        if (!ParseParameterReplication(&parameter_replication)) {
          return false;
        }
        static_cast<optional<ParameterReplication>*>(attr_out_ptr)
            ->emplace(parameter_replication);
        return true;
      }
      case AttrTy::kInstructionList: {
        std::vector<HloInstruction*> result;
        if (!ParseInstructionNames(&result)) {
          return false;
        }
        static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kFusionKind: {
        HloInstruction::FusionKind result;
        if (!ParseFusionKind(&result)) {
          return false;
        }
        static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kBracedInt64List: {
        std::vector<int64> result;
        if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
                            &result)) {
          return false;
        }
        static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kBracedInt64ListList: {
        std::vector<std::vector<int64>> result;
        if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace,
                                TokKind::kComma, &result)) {
          return false;
        }
        static_cast<optional<std::vector<std::vector<int64>>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kSliceRanges: {
        SliceRanges result;
        if (!ParseSliceRanges(&result)) {
          return false;
        }
        static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kPaddingConfig: {
        PaddingConfig result;
        if (!ParsePaddingConfig(&result)) {
          return false;
        }
        static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kString: {
        string result;
        if (!ParseString(&result)) {
          return false;
        }
        static_cast<optional<string>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kMetadata: {
        OpMetadata result;
        if (!ParseMetadata(&result)) {
          return false;
        }
        static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kDistribution: {
        RandomDistribution result;
        if (!ParseRandomDistribution(&result)) {
          return false;
        }
        static_cast<optional<RandomDistribution>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kDomain: {
        return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
      }
      case AttrTy::kPrecisionList: {
        std::vector<PrecisionConfig::Precision> result;
        if (!ParsePrecisionList(&result)) {
          return false;
        }
        static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
            attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kShapeList: {
        std::vector<Shape> result;
        if (!ParseShapeList(&result)) {
          return false;
        }
        static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
    }
  }();
  if (!success) {
    return Error(loc, StrFormat("error parsing attribute %s", name));
  }
  return true;
}