in be/src/exprs/ai-functions.cc [178:348]
// Core implementation of ai_generate_text(): builds an OpenAI-style
// chat-completion POST request, sends it to 'endpoint_sv', and returns the
// parsed model reply as a StringVal.
//
// Template parameters (declared above this chunk -- TODO confirm):
//   'fastpath' - presumably the pre-validated default-configuration path; when
//       true, per-call option parsing, per-call credentials, custom payloads
//       and platform-param overrides are all skipped and process-wide defaults
//       (FLAGS_ai_endpoint, FLAGS_ai_model, ai_api_key_) are used.
//   'platform' - the target AI service; only visibly consulted to skip the
//       "model" payload field for AZURE_OPEN_AI and to pick the auth header
//       format via getAuthorizationHeader<platform>().
//
// Parameters:
//   ctx             - UDF context used for error reporting and result allocation.
//   endpoint_sv     - target URL for the POST request.
//   prompt          - user prompt; becomes the single "user" message.
//   model           - optional per-call model override (non-fastpath only).
//   auth_credential - optional per-call credential: plain token or JCEKS secret
//                     name, per ai_options.credential_type (non-fastpath only).
//   platform_params - optional JSON object of payload overrides (non-fastpath).
//   impala_options  - optional JSON of Impala-specific options (non-fastpath).
//   dry_run         - when true, return the assembled request instead of
//                     sending it (used for testing/inspection).
//
// Returns the model response, an error-message StringVal (via
// RETURN_STRINGVAL_IF_ERROR, which returns out of this function on a bad
// Status), or StringVal::null() on transport failure.
StringVal AiFunctions::AiGenerateTextInternal(FunctionContext* ctx,
    const std::string_view& endpoint_sv, const StringVal& prompt, const StringVal& model,
    const StringVal& auth_credential, const StringVal& platform_params,
    const StringVal& impala_options, const bool dry_run) {
  // Generate the header for the POST request
  vector<string> headers;
  headers.emplace_back(OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER);
  string authHeader;
  AiFunctions::AiFunctionsOptions ai_options;
  // NOTE: 'impala_options_document' must stay alive as long as 'ai_options' is
  // used -- ai_options holds string_views (e.g. ai_custom_payload) that
  // presumably point into this document's storage. TODO confirm against
  // ParseImpalaOptions.
  Document impala_options_document;
  if (!fastpath) {
    try {
      ParseImpalaOptions(impala_options, impala_options_document, ai_options);
    } catch (const std::runtime_error& e) {
      // Malformed impala_options JSON: surface a parse error to the client.
      std::stringstream ss;
      ss << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": " << e.what();
      LOG(WARNING) << ss.str();
      const Status err_status(ss.str());
      RETURN_STRINGVAL_IF_ERROR(ctx, err_status);
    }
  }
  // Build the Authorization header. Per-call credentials are only honored on
  // the non-fastpath; otherwise the process-wide default key is used.
  if (!fastpath && auth_credential.ptr != nullptr && auth_credential.len != 0) {
    if (ai_options.credential_type == CREDENTIAL_TYPE::PLAIN) {
      // Use the credential as a plain text token.
      std::string_view token(
          reinterpret_cast<char*>(auth_credential.ptr), auth_credential.len);
      RETURN_STRINGVAL_IF_ERROR(
          ctx, getAuthorizationHeader<platform>(authHeader, token, ai_options));
    } else {
      DCHECK(ai_options.credential_type == CREDENTIAL_TYPE::JCEKS);
      // Use the credential as JCEKS secret and fetch API key.
      string api_key;
      string api_key_secret(
          reinterpret_cast<char*>(auth_credential.ptr), auth_credential.len);
      // Resolve the secret through the frontend's keystore service.
      RETURN_STRINGVAL_IF_ERROR(ctx,
          ExecEnv::GetInstance()->frontend()->GetSecretFromKeyStore(
              api_key_secret, &api_key));
      RETURN_STRINGVAL_IF_ERROR(
          ctx, getAuthorizationHeader<platform>(authHeader, api_key, ai_options));
    }
  } else {
    // Default: process-wide API key configured at startup.
    RETURN_STRINGVAL_IF_ERROR(
        ctx, getAuthorizationHeader<platform>(authHeader, ai_api_key_, ai_options));
  }
  headers.emplace_back(authHeader);
  string payload_str;
  if (!fastpath && !ai_options.ai_custom_payload.empty()) {
    // Caller supplied a fully-formed payload via impala_options; use it verbatim
    // (the prompt/model/platform_params inputs are ignored on this path).
    payload_str =
        string(ai_options.ai_custom_payload.data(), ai_options.ai_custom_payload.size());
  } else {
    // Generate the payload for the POST request
    Document payload;
    payload.SetObject();
    Document::AllocatorType& payload_allocator = payload.GetAllocator();
    // Azure Open AI endpoint doesn't expect model as a separate param since it's
    // embedded in the endpoint. The 'deployment_name' below maps to a model.
    // https://<resource_name>.openai.azure.com/openai/deployments/<deployment_name>/..
    if (platform != AI_PLATFORM::AZURE_OPEN_AI) {
      // NOTE: rapidjson::StringRef does NOT copy -- 'model'/'FLAGS_ai_model'
      // storage must outlive serialization below (both do: serialization
      // happens before this scope exits).
      if (!fastpath && model.ptr != nullptr && model.len != 0) {
        payload.AddMember("model",
            rapidjson::StringRef(reinterpret_cast<char*>(model.ptr), model.len),
            payload_allocator);
      } else {
        payload.AddMember("model",
            rapidjson::StringRef(FLAGS_ai_model.c_str(), FLAGS_ai_model.length()),
            payload_allocator);
      }
    }
    // Chat-completions shape: "messages": [{"role": "user", "content": <prompt>}].
    Value message_array(rapidjson::kArrayType);
    Value message(rapidjson::kObjectType);
    message.AddMember("role", "user", payload_allocator);
    if (prompt.ptr == nullptr || prompt.len == 0) {
      // Return a string with the invalid prompt error message instead of failing
      // the query, as the issue may be with the row rather than the configuration
      // or query. This behavior might be reconsidered later.
      return StringVal(AI_GENERATE_TXT_INVALID_PROMPT_ERROR.c_str());
    }
    message.AddMember("content",
        rapidjson::StringRef(reinterpret_cast<char*>(prompt.ptr), prompt.len),
        payload_allocator);
    message_array.PushBack(message, payload_allocator);
    payload.AddMember("messages", message_array, payload_allocator);
    // Override additional platform params.
    // Caution: 'payload' might reference data owned by 'overrides'.
    // To ensure valid access, place 'overrides' outside the 'if'
    // statement before using 'payload'.
    Document overrides;
    if (!fastpath && platform_params.ptr != nullptr && platform_params.len != 0) {
      overrides.Parse(reinterpret_cast<char*>(platform_params.ptr), platform_params.len);
      if (overrides.HasParseError()) {
        std::stringstream ss;
        ss << AI_GENERATE_TXT_JSON_PARSE_ERROR << ": error code "
           << overrides.GetParseError() << ", offset input "
           << overrides.GetErrorOffset();
        LOG(WARNING) << ss.str();
        const Status err_status(ss.str());
        RETURN_STRINGVAL_IF_ERROR(ctx, err_status);
      }
      // NOTE(review): assumes the parsed JSON is a top-level object; a valid
      // non-object (e.g. "[1]" or "42") would trip rapidjson's GetObject()
      // assertion -- TODO confirm callers/validation guarantee an object here.
      for (auto& m : overrides.GetObject()) {
        if (payload.HasMember(m.name.GetString())) {
          if (m.name == "messages") {
            // "messages" is derived from 'prompt'; overriding it would let a
            // parameter silently replace the row's prompt -- reject.
            const string error_msg = AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR
                + ": 'messages' is constructed from 'prompt', cannot be overridden";
            LOG(WARNING) << error_msg;
            const Status err_status(error_msg);
            RETURN_STRINGVAL_IF_ERROR(ctx, err_status);
          } else {
            // Replace an existing field (e.g. "model") with the override value.
            payload[m.name.GetString()] = m.value;
          }
        } else {
          // Only a single completion is supported: "n" may only be integer 1.
          if ((m.name == "n") && !(m.value.IsInt() && m.value.GetInt() == 1)) {
            const string error_msg = AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR
                + ": 'n' must be of integer type and have value 1";
            LOG(WARNING) << error_msg;
            const Status err_status(error_msg);
            RETURN_STRINGVAL_IF_ERROR(ctx, err_status);
          }
          // New field: moved out of 'overrides' into 'payload' (rapidjson
          // AddMember takes ownership by move for non-const Value refs).
          payload.AddMember(m.name, m.value, payload_allocator);
        }
      }
    }
    // Convert payload into string for POST request
    StringBuffer buffer;
    Writer<StringBuffer> writer(buffer);
    payload.Accept(writer);
    payload_str = string(buffer.GetString(), buffer.GetSize());
  }
  DCHECK(!payload_str.empty());
  VLOG(2) << "AI Generate Text: \nendpoint: " << endpoint_sv
          << " \npayload: " << payload_str;
  if (UNLIKELY(dry_run)) {
    // Dry run: return the would-be request (endpoint, headers, payload) as the
    // result instead of contacting the external service.
    std::stringstream post_request;
    post_request << endpoint_sv;
    for (auto& header : headers) {
      post_request << "\n" << header;
    }
    post_request << "\n" << payload_str;
    return StringVal::CopyFrom(ctx,
        reinterpret_cast<const uint8_t*>(post_request.str().data()),
        post_request.str().length());
  }
  // Send request to external AI API endpoint
  kudu::EasyCurl curl;
  curl.set_timeout(kudu::MonoDelta::FromSeconds(FLAGS_ai_connection_timeout_s));
  // Treat HTTP 4xx/5xx as an error status rather than a successful transfer.
  curl.set_fail_on_http_error(true);
  kudu::faststring resp;
  kudu::Status status;
  if (fastpath) {
    // Fastpath always targets the configured default endpoint.
    DCHECK_EQ(std::string_view(FLAGS_ai_endpoint), endpoint_sv);
    status = curl.PostToURL(FLAGS_ai_endpoint, payload_str, &resp, headers);
  } else {
    // PostToURL needs an owned, NUL-terminated string; materialize the view.
    std::string endpoint_str{endpoint_sv};
    status = curl.PostToURL(endpoint_str, payload_str, &resp, headers);
  }
  VLOG(2) << "AI Generate Text: \noriginal response: " << resp.ToString();
  if (UNLIKELY(!status.ok())) {
    // Transport/HTTP failure: record the error on the context, return NULL.
    SET_ERROR(ctx, status.ToString(), AI_GENERATE_TXT_COMMON_ERROR_PREFIX);
    return StringVal::null();
  }
  // Parse the JSON response string
  std::string response = AiGenerateTextParseOpenAiResponse(
      std::string_view(reinterpret_cast<char*>(resp.data()), resp.size()));
  VLOG(2) << "AI Generate Text: \nresponse: " << response;
  // Copy the reply into a context-allocated StringVal; a null result here
  // means the allocation failed (ctx already carries the error).
  StringVal result(ctx, response.length());
  if (UNLIKELY(result.is_null)) return StringVal::null();
  memcpy(result.ptr, response.data(), response.length());
  return result;
}