absl::Status MlPredictTableValuedFunctionEvaluator::Init()

in backend/query/ml/ml_predict_table_valued_function.cc [151:248]


// Binds the TVF's input relation and output schema to the model's columns.
//
// Populates input_values_/output_values_ (one slot per model input / output
// column), input_columns_ (which input column is copied into which slot),
// and model_inputs_/model_outputs_ (which slot backs which model column).
//
// Returns a non-OK status when:
//   * a required model input has no input column with a matching name,
//   * a model input or pass-through name matches more than one input column,
//   * a matched input column's type differs from the model column's type,
//   * an internal consistency check (ZETASQL_RET_CHECK*) fails.
absl::Status MlPredictTableValuedFunctionEvaluator::Init() {
  // Formats a type name for the error messages produced below.
  const auto external_type_name = [](const zetasql::Type* type) {
    return type->TypeName(zetasql::PRODUCT_EXTERNAL,
                          /*use_external_float32=*/true);
  };

  // Group input column positions by name (the map is keyed through
  // CaseInsensitiveStringMap); a name may map to several positions.
  CaseInsensitiveStringMap<std::vector<int64_t>> input_columns_by_name;
  for (int column_idx = 0; column_idx < input_->NumColumns(); ++column_idx) {
    input_columns_by_name[input_->GetColumnName(column_idx)].emplace_back(
        column_idx);
  }

  // One value slot per model input. Sized once up front because addresses of
  // individual slots are handed out inside the loop.
  input_values_.resize(model_->NumInputs());
  for (int model_idx = 0; model_idx < model_->NumInputs(); ++model_idx) {
    const QueryableModelColumn* model_column =
        model_->GetInput(model_idx)->GetAs<QueryableModelColumn>();
    ZETASQL_RET_CHECK(model_column);

    // Look the model input up among the input columns by name.
    const auto input_column = input_columns_by_name.find(model_column->Name());
    if (input_column == input_columns_by_name.end()) {
      // Optional model inputs may be absent; required ones fail the query.
      if (!model_column->required()) {
        continue;
      }
      return error::MlInputColumnMissing(
          model_column->Name(), external_type_name(model_column->GetType()));
    }

    // Several input columns sharing this name cannot be disambiguated.
    if (input_column->second.size() > 1) {
      return error::MlInputColumnAmbiguous(model_column->Name());
    }

    ZETASQL_RET_CHECK_EQ(input_column->second.size(), 1);
    const int64_t input_column_index = input_column->second.front();

    // The input column must have exactly the type the model declares.
    const zetasql::Type* input_column_type =
        input_->GetColumnType(input_column_index);
    const bool types_match =
        input_column_type->Equals(model_column->GetType());
    if (!types_match) {
      return error::MlInputColumnTypeMismatch(
          model_column->Name(), external_type_name(input_column_type),
          external_type_name(model_column->GetType()));
    }

    // The same slot is recorded both as the destination of the input column
    // and as the value backing the model input.
    auto* const value_slot = &input_values_[model_idx];
    input_columns_.push_back(
        InputColumn{.input_index = input_column_index, .value = value_slot});
    model_inputs_.insert(
        {model_column->Name(),
         ModelEvaluator::ModelColumn{.model_column = model_column,
                                     .value = value_slot}});
  }

  // One value slot per output column; each output column is either a model
  // output or a pass-through of an input column.
  output_values_.resize(output_columns_.size());
  for (int output_idx = 0; output_idx < output_columns_.size(); ++output_idx) {
    const auto& output_column = output_columns_[output_idx];
    const std::string& column_name = output_column.name;
    const zetasql::Type* column_type = output_column.type;
    auto* const value_slot = &output_values_[output_idx];

    // Model outputs take precedence over pass-through columns.
    const zetasql::Column* model_column =
        model_->FindOutputByName(column_name);
    if (model_column != nullptr) {
      ZETASQL_RET_CHECK(model_column->Is<QueryableModelColumn>());
      ZETASQL_RET_CHECK(model_column->GetType()->Equals(column_type));
      model_outputs_.insert(
          {model_column->Name(),
           ModelEvaluator::ModelColumn{
               .model_column = model_column->GetAs<QueryableModelColumn>(),
               .value = value_slot}});
      continue;
    }

    // Not a model output: it has to name an input column to pass through.
    const auto input_column = input_columns_by_name.find(column_name);
    if (input_column == input_columns_by_name.end()) {
      ZETASQL_RET_CHECK_FAIL() << "Could not match ML TVF Scan column " << column_name
                       << ". Matches should be ensured when resolving the TVF";
    }

    if (input_column->second.size() > 1) {
      return error::MlPassThroughColumnAmbiguous(column_name);
    }
    ZETASQL_RET_CHECK_EQ(input_column->second.size(), 1);
    const int64_t input_column_index = input_column->second.front();

    const zetasql::Type* input_column_type =
        input_->GetColumnType(input_column_index);
    ZETASQL_RET_CHECK(column_type->Equals(input_column_type));

    // Pass-through: the input column is copied straight into the output slot.
    input_columns_.push_back(
        InputColumn{.input_index = input_column_index, .value = value_slot});
  }

  return absl::OkStatus();
}