backend/query/ml/ml_predict_table_valued_function.cc (256 lines of code) (raw):
//
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "backend/query/ml/ml_predict_table_valued_function.h"
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "zetasql/public/analyzer_options.h"
#include "zetasql/public/catalog.h"
#include "zetasql/public/evaluator_table_iterator.h"
#include "zetasql/public/function_signature.h"
#include "zetasql/public/table_valued_function.h"
#include "zetasql/public/types/type.h"
#include "zetasql/public/types/type_factory.h"
#include "zetasql/public/value.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "backend/common/case.h"
#include "backend/query/ml/model_evaluator.h"
#include "backend/query/queryable_model.h"
#include "common/errors.h"
#include "zetasql/base/ret_check.h"
#include "zetasql/base/status_macros.h"
namespace google::spanner::emulator::backend {
namespace {
// Components of the multi-part name under which ML.PREDICT is registered.
// These live in an anonymous namespace and are constexpr (hence const), so they
// already have internal linkage without an explicit `static`.

// Leading component of the SAFE-prefixed variant of the function name.
constexpr absl::string_view kSafe = "SAFE";
// Namespace component shared by both variants.
constexpr absl::string_view kMlFunctionNamespace = "ML";
// Final component of the function name.
constexpr absl::string_view kFunctionName = "PREDICT";
// Returns the name path used to register the TVF:
// {"SAFE", "ML", "PREDICT"} when `safe` is true, {"ML", "PREDICT"} otherwise.
std::vector<std::string> FunctionName(bool safe) {
  std::vector<std::string> name_path;
  name_path.reserve(safe ? 3 : 2);
  if (safe) {
    name_path.emplace_back(kSafe);
  }
  name_path.emplace_back(kMlFunctionNamespace);
  name_path.emplace_back(kFunctionName);
  return name_path;
}
class MlPredictTableValuedFunctionEvaluator
: public zetasql::EvaluatorTableIterator {
public:
MlPredictTableValuedFunctionEvaluator(
const zetasql::Model* model,
std::unique_ptr<EvaluatorTableIterator> input,
zetasql::Value parameters,
const std::vector<zetasql::TVFSchemaColumn>& output_columns)
: model_(model),
input_(std::move(input)),
parameters_(std::move(parameters)),
output_columns_(output_columns) {}
// Validates inputs and initializes evaluator's state.
absl::Status Init();
int NumColumns() const override {
return static_cast<int>(output_columns_.size());
}
std::string GetColumnName(int i) const override {
DCHECK_GE(i, 0);
DCHECK_LT(i, output_columns_.size());
return output_columns_[i].name;
}
const zetasql::Type* GetColumnType(int i) const override {
DCHECK_GE(i, 0);
DCHECK_LT(i, output_columns_.size());
return output_columns_[i].type;
}
const zetasql::Value& GetValue(int i) const override {
DCHECK_GE(i, 0);
DCHECK_LT(i, output_columns_.size());
return output_values_[i];
}
absl::Status Status() const override { return status_; }
absl::Status Cancel() override { return input_->Cancel(); }
bool NextRow() override {
// Advance input iterator, stop if there is an error.
if (!input_->NextRow()) {
status_ = input_->Status();
return false;
}
// Get all the input values and populate pass-through columns.
for (auto& input_column : input_columns_) {
*input_column.value = input_->GetValue(input_column.input_index);
}
// Invoke model evaluator to populate output values.
status_ = ModelEvaluator::Predict(model_, model_inputs_, model_outputs_);
return status_.ok();
}
private:
// The model argument of ML.PREDICT function.
const zetasql::Model* const model_;
// The relation argument of ML.PREDICT function.
std::unique_ptr<EvaluatorTableIterator> input_;
// The parameters argument of ML.PREDICT function.
const zetasql::Value parameters_;
// Selected output columns: model outputs and pass-through columns.
const std::vector<zetasql::TVFSchemaColumn> output_columns_;
// Maps input iterator column index to either input_values_ for model inputs
// or output_values_ for pass-through columns.
struct InputColumn {
// Index of the input column value to be read.
int64_t input_index;
// Pointer to the value to be set.
zetasql::Value* value;
};
std::vector<InputColumn> input_columns_;
// Model input columns sent as arguments to ModelEvaluator.
CaseInsensitiveStringMap<const ModelEvaluator::ModelColumn> model_inputs_;
// Model output columns values of which are set by ModelEvaluator.
CaseInsensitiveStringMap<ModelEvaluator::ModelColumn> model_outputs_;
// Vector of values referenced by model_inputs_.
std::vector<zetasql::Value> input_values_;
// Vector of values accessible through GetValue().
std::vector<zetasql::Value> output_values_;
// Status of the iterator.
absl::Status status_;
};
// Validates the input relation against the model and binds every column.
//
// Model inputs are handled as follows:
//   * each model input with a matching input column is registered in
//     `model_inputs_`;
//   * that input is backed by a slot in `input_values_`.
//
// Each selected output column is bound either to a model output, which is
// filled by ModelEvaluator::Predict, or to a pass-through input column.
//
// Returns an error in these cases:
//   * a required model input has no matching input column;
//   * a matching input column is ambiguous or has the wrong type;
//   * a pass-through column is ambiguous.
// Returns an internal (RET_CHECK) error on inconsistencies that resolution
// should have prevented.
//
// `input_columns_`, `model_inputs_` and `model_outputs_` hold raw pointers into
// `input_values_` and `output_values_`. Both vectors are therefore sized
// exactly once here, before any pointer is taken.
absl::Status MlPredictTableValuedFunctionEvaluator::Init() {
  // Index input columns by name. One name may map to several columns; that is
  // reported as ambiguity only where the name is actually needed.
  CaseInsensitiveStringMap<std::vector<int64_t>> input_columns_by_name;
  for (int i = 0; i < input_->NumColumns(); ++i) {
    input_columns_by_name[input_->GetColumnName(i)].emplace_back(i);
  }

  // Validate that model inputs are satisfied and build model_inputs_.
  input_values_.resize(model_->NumInputs());
  for (int i = 0; i < model_->NumInputs(); ++i) {
    // Null-check, then verify the dynamic type with Is<> before GetAs<>
    // (matching the output-column handling below). A null check on the
    // result of GetAs<> would not detect a column of the wrong type.
    const zetasql::Column* column = model_->GetInput(i);
    ZETASQL_RET_CHECK_NE(column, nullptr);
    ZETASQL_RET_CHECK(column->Is<QueryableModelColumn>());
    const QueryableModelColumn* model_column =
        column->GetAs<QueryableModelColumn>();

    // Find matching input column by name.
    auto input_column = input_columns_by_name.find(model_column->Name());
    if (input_column == input_columns_by_name.end()) {
      // If column is required, fail the query.
      if (model_column->required()) {
        return error::MlInputColumnMissing(
            model_column->Name(),
            model_column->GetType()->TypeName(zetasql::PRODUCT_EXTERNAL,
                                              /*use_external_float32=*/true));
      }
      // Ignore missing optional columns.
      continue;
    }
    // If there is more than one matching input column, raise ambiguous error.
    if (input_column->second.size() > 1) {
      return error::MlInputColumnAmbiguous(model_column->Name());
    }
    ZETASQL_RET_CHECK_EQ(input_column->second.size(), 1);
    const int64_t input_column_index = input_column->second.front();
    const zetasql::Type* input_column_type =
        input_->GetColumnType(input_column_index);
    if (!input_column_type->Equals(model_column->GetType())) {
      return error::MlInputColumnTypeMismatch(
          model_column->Name(),
          input_column_type->TypeName(zetasql::PRODUCT_EXTERNAL,
                                      /*use_external_float32=*/true),
          model_column->GetType()->TypeName(zetasql::PRODUCT_EXTERNAL,
                                            /*use_external_float32=*/true));
    }
    input_columns_.push_back(InputColumn{.input_index = input_column_index,
                                         .value = &input_values_[i]});
    model_inputs_.insert(
        {model_column->Name(),
         ModelEvaluator::ModelColumn{.model_column = model_column,
                                     .value = &input_values_[i]}});
  }

  // Map output columns to model outputs or passthrough columns. Iterating up
  // to NumColumns() (an int) avoids a signed/unsigned comparison.
  output_values_.resize(output_columns_.size());
  for (int i = 0; i < NumColumns(); ++i) {
    const std::string& column_name = output_columns_[i].name;
    const zetasql::Type* column_type = output_columns_[i].type;

    // Output of the model, not a pass through column.
    const zetasql::Column* model_column =
        model_->FindOutputByName(column_name);
    if (model_column != nullptr) {
      ZETASQL_RET_CHECK(model_column->Is<QueryableModelColumn>());
      ZETASQL_RET_CHECK(model_column->GetType()->Equals(column_type));
      model_outputs_.insert(
          {model_column->Name(),
           ModelEvaluator::ModelColumn{
               .model_column = model_column->GetAs<QueryableModelColumn>(),
               .value = &output_values_[i]}});
      continue;
    }

    // If the output column matches an input column, it's a pass-through column.
    auto input_column = input_columns_by_name.find(column_name);
    if (input_column != input_columns_by_name.end()) {
      if (input_column->second.size() > 1) {
        return error::MlPassThroughColumnAmbiguous(column_name);
      }
      ZETASQL_RET_CHECK_EQ(input_column->second.size(), 1);
      const int64_t input_column_index = input_column->second.front();
      const zetasql::Type* input_column_type =
          input_->GetColumnType(input_column_index);
      ZETASQL_RET_CHECK(column_type->Equals(input_column_type));
      input_columns_.push_back(InputColumn{.input_index = input_column_index,
                                           .value = &output_values_[i]});
      continue;
    }

    ZETASQL_RET_CHECK_FAIL() << "Could not match ML TVF Scan column " << column_name
                     << ". Matches should be ensured when resolving the TVF";
  }
  return absl::OkStatus();
}
} // namespace
// Declares the TVF under the name path returned by FunctionName(safe), i.e.
// ML.PREDICT or SAFE.ML.PREDICT, with the generic signature
//   (ANY MODEL, ANY TABLE [, optional ANY STRUCT]) -> ANY TABLE.
// The concrete output schema is produced per call by Resolve().
MlPredictTableValuedFunction::MlPredictTableValuedFunction(bool safe)
    : zetasql::TableValuedFunction(
          FunctionName(safe),
          zetasql::FunctionSignature(
              // Result: a relation whose columns are decided in Resolve().
              zetasql::FunctionArgumentType::AnyRelation(),
              {
                  // Argument 0: the model to evaluate.
                  zetasql::FunctionArgumentType::AnyModel(),
                  // Argument 1: the relation whose rows are fed to the model.
                  zetasql::FunctionArgumentType::AnyRelation(),
                  // Argument 2 (optional): a struct value that CreateEvaluator
                  // forwards to the evaluator as its parameters.
                  zetasql::FunctionArgumentType(
                      zetasql::ARG_STRUCT_ANY,
                      zetasql::FunctionArgumentTypeOptions(
                          zetasql::FunctionArgumentType::OPTIONAL)),
              },
              /*context_ptr=*/nullptr)),
      safe_(safe) {}
// Computes the output schema of one ML.PREDICT call.
//
// The result relation lists every model output column, in model order. Those
// are followed by each input-relation column whose name differs
// (ASCII-case-insensitively) from all model output names. An input column that
// shares a name with a model output is dropped in favor of the model output.
absl::Status MlPredictTableValuedFunction::Resolve(
    const zetasql::AnalyzerOptions* analyzer_options,
    const std::vector<zetasql::TVFInputArgumentType>& actual_arguments,
    const zetasql::FunctionSignature& concrete_signature,
    zetasql::Catalog* catalog, zetasql::TypeFactory* type_factory,
    std::shared_ptr<zetasql::TVFSignature>* output_tvf_signature) const {
  // Expected arguments: model, relation, and an optional third argument.
  ZETASQL_RET_CHECK_GE(actual_arguments.size(), 2);
  ZETASQL_RET_CHECK_LE(actual_arguments.size(), 3);

  const zetasql::TVFInputArgumentType& model_arg = actual_arguments[0];
  ZETASQL_RET_CHECK(model_arg.is_model());
  ZETASQL_RET_CHECK_NE(model_arg.model().model(), nullptr);
  const zetasql::Model& ml_model = *model_arg.model().model();

  const zetasql::TVFInputArgumentType& relation_arg = actual_arguments[1];
  ZETASQL_RET_CHECK(relation_arg.is_relation());
  const zetasql::TVFRelation& input_relation = relation_arg.relation();

  std::vector<zetasql::TVFRelation::Column> result_columns;
  result_columns.reserve(ml_model.NumOutputs() + input_relation.num_columns());
  // Lowercased model output names; input columns with these names are dropped.
  absl::flat_hash_set<std::string> shadowed_names;
  shadowed_names.reserve(ml_model.NumOutputs());
  ZETASQL_RET_CHECK_GT(ml_model.NumOutputs(), 0);

  for (int out = 0; out < ml_model.NumOutputs(); ++out) {
    const zetasql::Column* output = ml_model.GetOutput(out);
    ZETASQL_RET_CHECK_NE(output, nullptr);
    result_columns.emplace_back(output->Name(), output->GetType(), false);
    shadowed_names.emplace(absl::AsciiStrToLower(output->Name()));
  }

  for (const zetasql::TVFSchemaColumn& column : input_relation.columns()) {
    if (shadowed_names.contains(absl::AsciiStrToLower(column.name))) {
      continue;
    }
    result_columns.push_back(column);
  }

  ZETASQL_RET_CHECK_NE(output_tvf_signature, nullptr);
  *output_tvf_signature = std::make_shared<zetasql::TVFSignature>(
      actual_arguments, zetasql::TVFRelation(std::move(result_columns)));
  return absl::OkStatus();
}
// Builds the row iterator that evaluates one ML.PREDICT call.
//
// `input_arguments` holds the following:
//   * index 0: the model;
//   * index 1: the input relation;
//   * index 2 (optional): the parameters value.
// `output_columns` lists the columns the returned iterator must produce, in
// order.
//
// Returns an error in two cases:
//   * the arguments are malformed;
//   * the evaluator's Init() rejects the input relation, for example because
//     of a missing, ambiguous or mistyped model input column.
absl::StatusOr<std::unique_ptr<zetasql::EvaluatorTableIterator>>
MlPredictTableValuedFunction::CreateEvaluator(
    std::vector<TvfEvaluatorArg> input_arguments,
    const std::vector<zetasql::TVFSchemaColumn>& output_columns,
    const zetasql::FunctionSignature* function_call_signature) const {
  ZETASQL_RET_CHECK_GE(input_arguments.size(), 2);
  ZETASQL_RET_CHECK_LE(input_arguments.size(), 3);
  ZETASQL_RET_CHECK(input_arguments[0].model);
  const zetasql::Model* model = input_arguments[0].model;
  ZETASQL_RET_CHECK(input_arguments[1].relation);
  std::unique_ptr<zetasql::EvaluatorTableIterator> input =
      std::move(input_arguments[1].relation);

  // The parameters argument is optional; when omitted, `parameters` stays
  // default-constructed. `input_arguments` is owned by value here, so the
  // value can be moved out instead of copied.
  zetasql::Value parameters;
  if (input_arguments.size() >= 3) {
    ZETASQL_RET_CHECK(input_arguments[2].value);
    parameters = std::move(*input_arguments[2].value);
  }

  // `output_columns` is a const reference, so it is passed as-is (std::move on
  // it would still copy). `parameters` is a local and is moved.
  auto evaluator = std::make_unique<MlPredictTableValuedFunctionEvaluator>(
      model, std::move(input), std::move(parameters), output_columns);
  ZETASQL_RETURN_IF_ERROR(evaluator->Init());
  // Explicit move into the StatusOr holding a base-class unique_ptr.
  return std::move(evaluator);
}
} // namespace google::spanner::emulator::backend