absl::Status QueryValidator::CheckHintValue()

in backend/query/query_validator.cc [286:417]


// Validates the value of a single query hint.
//
// `name` must be one of the hint names registered in `supported_hint_types`
// (matched case-insensitively); an unregistered name is an internal error
// (ZETASQL_RET_CHECK). `value` must have exactly the type registered for that
// hint, and string-valued hints must be non-NULL and, where the hint has a
// fixed vocabulary, case-insensitively equal one of the accepted options.
//
// Hint-specific checks:
//   * FORCE_INDEX: any value other than kHintBaseTable is rejected on a
//     RESOLVED_QUERY_STMT node; otherwise it must name an index known to the
//     schema, and every index found under that name is added to
//     `indexes_used_`.
//   * HASH_JOIN_BUILD_SIDE: `hint_map` must contain a join-method hint (or its
//     deprecated alias) whose value is kHintJoinTypeHash.
//
// Returns OkStatus if the hint is valid, otherwise an error status describing
// the problem.
absl::Status QueryValidator::CheckHintValue(
    absl::string_view name, const zetasql::Value& value,
    const zetasql::ResolvedNodeKind node_kind,
    const absl::flat_hash_map<absl::string_view, zetasql::Value>& hint_map) {
  static const auto* supported_hint_types =
      new absl::flat_hash_map<absl::string_view, const zetasql::Type*,
                              zetasql_base::StringViewCaseHash, zetasql_base::StringViewCaseEqual>{{
          {kHintJoinTypeDeprecated, zetasql::types::StringType()},
          {kHintParameterSensitive, zetasql::types::StringType()},
          {kHintJoinMethod, zetasql::types::StringType()},
          {kHashJoinBuildSide, zetasql::types::StringType()},
          {kHashJoinExecution, zetasql::types::StringType()},
          {kHintJoinBatch, zetasql::types::BoolType()},
          {kHintJoinForceOrder, zetasql::types::BoolType()},
          {kHintGroupTypeDeprecated, zetasql::types::StringType()},
          {kHintGroupMethod, zetasql::types::StringType()},
          {kHintForceIndex, zetasql::types::StringType()},
          {kUseAdditionalParallelism, zetasql::types::BoolType()},
          {kHintLockScannedRange, zetasql::types::StringType()},
          {kHintConstantFolding, zetasql::types::BoolType()},
          {kHintTableScanGroupByScanOptimization, zetasql::types::BoolType()},
          {kHintEnableAdaptivePlans, zetasql::types::BoolType()},
          {kHintDisableInline, zetasql::types::BoolType()},
          {kHintIndexStrategy, zetasql::types::StringType()},
          {kHintAllowSearchIndexesInTransaction, zetasql::types::BoolType()},
          {kRequireEnhanceQuery, zetasql::types::BoolType()},
          {kEnhanceQueryTimeoutMs, zetasql::types::Int64Type()},
          {kScanMethod, zetasql::types::StringType()},
      }};

  const auto iter = supported_hint_types->find(name);
  ZETASQL_RET_CHECK(iter != supported_hint_types->cend());
  if (!value.type()->Equals(iter->second)) {
    return error::InvalidHintValue(name, value.DebugString());
  }
  // Every string-valued hint below reads value.string_value(), which is only
  // valid on a non-NULL STRING value, so reject a NULL string hint here rather
  // than letting the accessor abort.
  if (value.type()->IsString() && value.is_null()) {
    return error::InvalidHintValue(name, value.DebugString());
  }

  // Returns true if `candidate` case-insensitively equals any of `allowed`.
  const auto matches_any =
      [](absl::string_view candidate,
         std::initializer_list<absl::string_view> allowed) {
        for (absl::string_view option : allowed) {
          if (absl::EqualsIgnoreCase(candidate, option)) return true;
        }
        return false;
      };

  if (absl::EqualsIgnoreCase(name, kHintForceIndex)) {
    const std::string& index_name = value.string_value();
    // kHintBaseTable is always accepted and requires no index lookup.
    if (!absl::EqualsIgnoreCase(index_name, kHintBaseTable)) {
      // Statement-level FORCE_INDEX hints can only be '_BASE_TABLE'.
      if (node_kind == zetasql::RESOLVED_QUERY_STMT) {
        return error::InvalidStatementHintValue(name, value.DebugString());
      }
      const std::vector<const Index*> indexes =
          context_.schema->FindIndexesUnderName(index_name);
      if (indexes.empty()) {
        // The table name is not available here, so this error message will not
        // exactly match the production error message.
        return error::QueryHintIndexNotFound("", index_name);
      }
      for (const Index* index : indexes) {
        indexes_used_.insert(index);
      }
    }
  } else if (absl::EqualsIgnoreCase(name, kHintJoinMethod) ||
             absl::EqualsIgnoreCase(name, kHintJoinTypeDeprecated)) {
    if (!matches_any(value.string_value(),
                     {kHintJoinTypeApply, kHintJoinTypeHash, kHintJoinTypeMerge,
                      kHintJoinTypePushBroadcastHashJoin,
                      kHintJoinTypeNestedLoopDeprecated})) {
      return error::InvalidHintValue(name, value.DebugString());
    }
  } else if (absl::EqualsIgnoreCase(name, kHintParameterSensitive)) {
    if (!matches_any(value.string_value(),
                     {kHintParameterSensitiveAlways, kHintParameterSensitiveAuto,
                      kHintParameterSensitiveNever})) {
      return error::InvalidHintValue(name, value.DebugString());
    }
  } else if (absl::EqualsIgnoreCase(name, kHashJoinBuildSide)) {
    // HASH_JOIN_BUILD_SIDE requires a sibling join-method hint selecting a
    // hash join. Hint names are compared case-insensitively, consistent with
    // `supported_hint_types`. The sibling hint may not have been validated yet
    // (the order in which hints are checked is not guaranteed), so its type
    // and nullness are checked before reading it with string_value().
    bool is_hash_join = false;
    for (const auto& [hint_name, hint_value] : hint_map) {
      if ((absl::EqualsIgnoreCase(hint_name, kHintJoinMethod) ||
           absl::EqualsIgnoreCase(hint_name, kHintJoinTypeDeprecated)) &&
          hint_value.type()->IsString() && !hint_value.is_null() &&
          absl::EqualsIgnoreCase(hint_value.string_value(),
                                 kHintJoinTypeHash)) {
        is_hash_join = true;
        break;
      }
    }
    if (!is_hash_join) {
      return error::InvalidHintForNode(kHashJoinBuildSide, "HASH joins");
    }
    if (!matches_any(value.string_value(),
                     {kHashJoinBuildSideLeft, kHashJoinBuildSideRight})) {
      return error::InvalidHintValue(name, value.DebugString());
    }
  } else if (absl::EqualsIgnoreCase(name, kHashJoinExecution)) {
    if (!matches_any(value.string_value(),
                     {kHashJoinExecutionOnePass, kHashJoinExecutionMultiPass})) {
      return error::InvalidHintValue(name, value.DebugString());
    }
  } else if (absl::EqualsIgnoreCase(name, kHintGroupMethod) ||
             absl::EqualsIgnoreCase(name, kHintGroupTypeDeprecated)) {
    if (!matches_any(value.string_value(),
                     {kHintGroupMethodHash, kHintGroupMethodStream})) {
      return error::InvalidHintValue(name, value.DebugString());
    }
  } else if (absl::EqualsIgnoreCase(name, kHintLockScannedRange)) {
    if (!matches_any(value.string_value(), {kHintLockScannedRangeExclusive,
                                            kHintLockScannedRangeShared})) {
      return error::InvalidHintValue(name, value.DebugString());
    }
  } else if (absl::EqualsIgnoreCase(name, kHintIndexStrategy)) {
    if (!absl::EqualsIgnoreCase(value.string_value(),
                                kHintIndexStrategyForceIndexUnion)) {
      return error::InvalidHintValue(name, value.DebugString());
    }
  } else if (absl::EqualsIgnoreCase(name, kScanMethod)) {
    if (!matches_any(value.string_value(),
                     {kScanMethodBatch, kScanMethodRow})) {
      return error::InvalidHintValue(name, value.DebugString());
    }
  }
  return absl::OkStatus();
}