bool SubstraitToVeloxPlanValidator::validateScalarFunction()

in cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc [185:247]


/// Validates whether a Substrait scalar function call can be offloaded to Velox.
///
/// Every argument is first validated recursively with validateExpression() and
/// converted to a Velox typed expression. The function name and argument type
/// names are then decoded from the signature registered in the plan's function
/// map. Finally, the call is checked against:
///   - function-specific restrictions (round, extract, char_length, concat,
///     murmur3hash, and the unsupported map_from_arrays / get_array_item);
///   - the regex-function rules (kRegexFunctions);
///   - the black list (kBlackList).
///
/// @param scalarFunction Substrait scalar function call to validate.
/// @param inputType Row type the argument expressions are resolved against.
/// @return true if the call passes validation. Otherwise false, with the
///         reason reported through LOG_VALIDATION_MSG (or by the delegated
///         validator).
bool SubstraitToVeloxPlanValidator::validateScalarFunction(
    const ::substrait::Expression::ScalarFunction& scalarFunction,
    const RowTypePtr& inputType) {
  std::vector<core::TypedExprPtr> params;
  params.reserve(scalarFunction.arguments().size());
  for (const auto& argument : scalarFunction.arguments()) {
    // Only value arguments can be converted into Velox expressions. For an
    // enum/type argument, argument.value() is an empty default Expression that
    // the converter cannot handle, so fail validation instead of converting it.
    if (!argument.has_value()) {
      LOG_VALIDATION_MSG("Only value arguments are supported in scalar function.");
      return false;
    }
    if (!validateExpression(argument.value(), inputType)) {
      return false;
    }
    params.emplace_back(exprConverter_->toVeloxExpr(argument.value(), inputType));
  }

  const auto& function =
      SubstraitParser::findFunctionSpec(planConverter_.getFunctionMap(), scalarFunction.function_reference());
  const auto& name = SubstraitParser::getNameBeforeDelimiter(function);
  // Argument type names as encoded in the function signature (e.g. "vbin").
  const std::vector<std::string> types = SubstraitParser::getSubFunctionTypes(function);

  if (name == "round") {
    return validateRound(scalarFunction, inputType);
  } else if (name == "extract") {
    return validateExtractExpr(params);
  } else if (name == "char_length") {
    // A malformed signature is a validation failure, not a fatal error.
    if (types.size() != 1) {
      LOG_VALIDATION_MSG("Expected exactly one argument type in " + name);
      return false;
    }
    if (types[0] == "vbin") {
      LOG_VALIDATION_MSG("Binary type is not supported in " + name);
      return false;
    }
  } else if (name == "map_from_arrays") {
    LOG_VALIDATION_MSG("map_from_arrays is not supported.");
    return false;
  } else if (name == "get_array_item") {
    LOG_VALIDATION_MSG("get_array_item is not supported.");
    return false;
  } else if (name == "concat" || name == "murmur3hash") {
    // Neither function accepts complex (struct / map / list) argument types.
    for (const auto& type : types) {
      if (type.find("struct") != std::string::npos || type.find("map") != std::string::npos ||
          type.find("list") != std::string::npos) {
        LOG_VALIDATION_MSG(type + " is not supported in " + name + ".");
        return false;
      }
    }
  }

  // Validate regex functions.
  if (kRegexFunctions.find(name) != kRegexFunctions.end()) {
    return validateRegexExpr(name, scalarFunction);
  }

  if (kBlackList.find(name) != kBlackList.end()) {
    LOG_VALIDATION_MSG("Function is not supported: " + name);
    return false;
  }

  return true;
}