in src/afs.cc [3970:4061]
// Authenticate an incoming Flight call from its "authorization" header
// and attach a HeaderAuthServerMiddleware carrying the session ID.
//
// Two schemes are accepted:
//   * "Basic <base64(user:password)>": opens a new session via
//     proxy_->connect() on the database named by the optional
//     "x-flight-sql-database" header ("postgres" when it is absent).
//   * "Bearer <session ID>": reuses an existing session. The token must
//     consist only of decimal digits, fit in uint64_t and be accepted by
//     proxy_->is_valid_session().
//
// Returns OK with *middleware set on success. Returns a Flight error
// (Unauthenticated or Unauthorized) and leaves *middleware untouched
// otherwise.
arrow::Status StartCall(
	const arrow::flight::CallInfo& info,
#if ARROW_VERSION_MAJOR >= 13
	const arrow::flight::ServerCallContext& context,
#else
	const arrow::flight::CallHeaders& incoming_headers,
#endif
	std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override
{
	std::string databaseName("postgres");
#if ARROW_VERSION_MAJOR >= 13
	const auto& incomingHeaders = context.incoming_headers();
#else
	const auto& incomingHeaders = incoming_headers;
#endif
	auto databaseHeader = incomingHeaders.find("x-flight-sql-database");
	if (databaseHeader != incomingHeaders.end())
	{
		databaseName = databaseHeader->second;
	}
	auto authorizationHeader = incomingHeaders.find("authorization");
	if (authorizationHeader == incomingHeaders.end())
	{
		return arrow::flight::MakeFlightError(
			arrow::flight::FlightStatusCode::Unauthenticated,
			"No authorization header");
	}

	// Split "<type> <credentials>" at the first space. A header without
	// any space (e.g. just "Basic") yields empty credentials instead of
	// an out-of-range substr() that would throw out of this handler.
	const std::string value(authorizationHeader->second);
	const auto separator = value.find(' ');
	const std::string type = value.substr(0, separator);
	const std::string credentials = (separator == std::string::npos)
										? std::string()
										: value.substr(separator + 1);

	if (type == "Basic")
	{
		// The decoded payload is "user:password"; everything after the
		// first ':' belongs to the password.
		const std::string decoded = arrow::util::base64_decode(credentials);
		const auto colon = decoded.find(':');
		std::string userName = decoded.substr(0, colon);
		std::string password = (colon == std::string::npos)
								   ? std::string()
								   : decoded.substr(colon + 1);
#if ARROW_VERSION_MAJOR >= 13
		const auto& clientAddress = context.peer();
#else
		// The peer address isn't available here. Use 192.0.2.1, which
		// is one of reserved IPv4 addresses for documentation.
		std::string clientAddress("ipv4:192.0.2.1:2929");
#endif
		auto sessionIDResult =
			proxy_->connect(databaseName, userName, password, clientAddress);
		if (!sessionIDResult.status().ok())
		{
			return arrow::flight::MakeFlightError(
				arrow::flight::FlightStatusCode::Unauthenticated,
				sessionIDResult.status().ToString());
		}
		auto sessionID = *sessionIDResult;
		*middleware = std::make_shared<HeaderAuthServerMiddleware>(sessionID);
		return arrow::Status::OK();
	}
	else if (type == "Bearer")
	{
		// Parse the token strictly: decimal digits only, no sign, no
		// whitespace and no overflow. Otherwise malformed tokens such as
		// "-1" or an over-long number could alias another session ID.
		const uint64_t maxSessionID = ~static_cast<uint64_t>(0);
		uint64_t sessionID = 0;
		bool validToken = !credentials.empty();
		for (const char c : credentials)
		{
			if (c < '0' || c > '9')
			{
				validToken = false;
				break;
			}
			const uint64_t digit = static_cast<uint64_t>(c - '0');
			if (sessionID > (maxSessionID - digit) / 10)
			{
				validToken = false;
				break;
			}
			sessionID = sessionID * 10 + digit;
		}
		if (!validToken || !proxy_->is_valid_session(sessionID))
		{
			return arrow::flight::MakeFlightError(
				arrow::flight::FlightStatusCode::Unauthorized,
				std::string("invalid Bearer token"));
		}
		*middleware = std::make_shared<HeaderAuthServerMiddleware>(sessionID);
		return arrow::Status::OK();
	}
	else
	{
		return arrow::flight::MakeFlightError(
			arrow::flight::FlightStatusCode::Unauthenticated,
			std::string("authorization header must use Basic or Bearer: <") + type +
				std::string(">"));
	}
}