in backend/query/ml/ml_predict_table_valued_function.cc [151:248]
// Wires the ML.PREDICT table-valued function evaluator to its input relation.
//
// Model inputs:
//   - Each model input is matched by name against the columns of `input_`.
//   - A missing required input fails with MlInputColumnMissing.
//   - A missing optional input is skipped.
//   - More than one matching column fails with MlInputColumnAmbiguous.
//   - A matching column of a different type fails with
//     MlInputColumnTypeMismatch.
//   - A successful match is recorded in both `input_columns_` and
//     `model_inputs_`. Both entries point at the same slot in `input_values_`.
//
// Output columns (`output_columns_`):
//   - A column named like a model output is recorded in `model_outputs_`,
//     pointing at its slot in `output_values_`.
//   - Model outputs take precedence over same-named input columns.
//   - Otherwise a same-named input column makes it a pass-through column,
//     recorded in `input_columns_`.
//   - An ambiguous pass-through name fails with MlPassThroughColumnAmbiguous.
//   - Any other column is an internal error; the resolver is expected to have
//     prevented it.
//
// Returns OK once every column is bound.
absl::Status MlPredictTableValuedFunctionEvaluator::Init() {
  // Positions of the input relation's columns, grouped by column name. More
  // than one position under a name means the input has duplicate names.
  CaseInsensitiveStringMap<std::vector<int64_t>> positions_by_name;
  for (int col = 0; col < input_->NumColumns(); ++col) {
    positions_by_name[input_->GetColumnName(col)].emplace_back(col);
  }

  // Size the value storage before handing out element addresses below; this
  // function never resizes it afterwards, so the addresses stay stable here.
  // NOTE(review): presumably nothing else resizes it later either — confirm.
  input_values_.resize(model_->NumInputs());
  for (int idx = 0; idx < model_->NumInputs(); ++idx) {
    const QueryableModelColumn* const wanted =
        model_->GetInput(idx)->GetAs<QueryableModelColumn>();
    ZETASQL_RET_CHECK(wanted);

    const auto match = positions_by_name.find(wanted->Name());
    if (match == positions_by_name.end()) {
      if (!wanted->required()) {
        // Optional model inputs may simply be absent from the input relation.
        continue;
      }
      return error::MlInputColumnMissing(
          wanted->Name(),
          wanted->GetType()->TypeName(zetasql::PRODUCT_EXTERNAL,
                                      /*use_external_float32=*/true));
    }

    const std::vector<int64_t>& positions = match->second;
    if (positions.size() > 1) {
      return error::MlInputColumnAmbiguous(wanted->Name());
    }
    ZETASQL_RET_CHECK_EQ(positions.size(), 1);

    const int64_t position = positions.front();
    const zetasql::Type* const actual_type = input_->GetColumnType(position);
    const zetasql::Type* const expected_type = wanted->GetType();
    if (!actual_type->Equals(expected_type)) {
      return error::MlInputColumnTypeMismatch(
          wanted->Name(),
          actual_type->TypeName(zetasql::PRODUCT_EXTERNAL,
                                /*use_external_float32=*/true),
          expected_type->TypeName(zetasql::PRODUCT_EXTERNAL,
                                  /*use_external_float32=*/true));
    }

    // The input copy and the model both refer to the same value slot.
    auto* const slot = &input_values_[idx];
    input_columns_.push_back(
        InputColumn{.input_index = position, .value = slot});
    model_inputs_.insert(
        {wanted->Name(),
         ModelEvaluator::ModelColumn{.model_column = wanted, .value = slot}});
  }

  output_values_.resize(output_columns_.size());
  for (int out = 0; out < output_columns_.size(); ++out) {
    const std::string& out_name = output_columns_[out].name;
    const zetasql::Type* const out_type = output_columns_[out].type;
    auto* const out_slot = &output_values_[out];

    // Columns produced by the model are checked first, so they shadow any
    // same-named input column.
    if (const zetasql::Column* produced = model_->FindOutputByName(out_name);
        produced != nullptr) {
      ZETASQL_RET_CHECK(produced->Is<QueryableModelColumn>());
      ZETASQL_RET_CHECK(produced->GetType()->Equals(out_type));
      model_outputs_.insert(
          {produced->Name(),
           ModelEvaluator::ModelColumn{
               .model_column = produced->GetAs<QueryableModelColumn>(),
               .value = out_slot}});
      continue;
    }

    // Anything else must be a pass-through copy of an input column.
    const auto passthrough = positions_by_name.find(out_name);
    if (passthrough == positions_by_name.end()) {
      ZETASQL_RET_CHECK_FAIL() << "Could not match ML TVF Scan column " << out_name
                       << ". Matches should be ensured when resolving the TVF";
    }

    const std::vector<int64_t>& positions = passthrough->second;
    if (positions.size() > 1) {
      return error::MlPassThroughColumnAmbiguous(out_name);
    }
    ZETASQL_RET_CHECK_EQ(positions.size(), 1);

    const int64_t position = positions.front();
    ZETASQL_RET_CHECK(out_type->Equals(input_->GetColumnType(position)));
    input_columns_.push_back(
        InputColumn{.input_index = position, .value = out_slot});
  }
  return absl::OkStatus();
}