arrow::Status StartCall()

in src/afs.cc [3970:4061]


	// Authenticates an incoming Flight call from its request headers and, on
	// success, installs a HeaderAuthServerMiddleware bound to a session ID.
	//
	// Headers consulted:
	// - "x-flight-sql-database": database name passed to proxy_->connect();
	//   "postgres" is used when the header is absent.
	// - "authorization" (required), in the form "<scheme> <credentials>":
	//   - "Basic <base64(user:password)>" opens a new session through
	//     proxy_->connect().
	//   - "Bearer <decimal session ID>" reuses a session that
	//     proxy_->is_valid_session() accepts.
	//
	// info: call information (not used here).
	// context / incoming_headers: source of the request headers (and, on
	//   Arrow >= 13, of the client peer address).
	// middleware: receives the created middleware on success.
	// Returns OK on success, otherwise a Flight error (Unauthenticated or
	// Unauthorized) describing why the call was rejected.
	arrow::Status StartCall(
		const arrow::flight::CallInfo& info,
#if ARROW_VERSION_MAJOR >= 13
		const arrow::flight::ServerCallContext& context,
#else
		const arrow::flight::CallHeaders& incoming_headers,
#endif
		std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override
	{
		std::string databaseName("postgres");
#if ARROW_VERSION_MAJOR >= 13
		const auto& incomingHeaders = context.incoming_headers();
#else
		const auto& incomingHeaders = incoming_headers;
#endif
		auto databaseHeader = incomingHeaders.find("x-flight-sql-database");
		if (databaseHeader != incomingHeaders.end())
		{
			databaseName = databaseHeader->second;
		}
		auto authorizationHeader = incomingHeaders.find("authorization");
		if (authorizationHeader == incomingHeaders.end())
		{
			return arrow::flight::MakeFlightError(
				arrow::flight::FlightStatusCode::Unauthenticated,
				"No authorization header");
		}

		// Split "<scheme> <credentials>" at the first space. A header with no
		// space at all (e.g. just "Bearer") yields empty credentials. The
		// previous getline()/tellg() approach passed tellg()'s -1 sentinel to
		// substr() in that case, which threw std::out_of_range on
		// client-controlled input.
		const auto value = authorizationHeader->second;
		const auto separator = value.find(' ');
		const std::string type(value.substr(0, separator));
		const auto credentials = (separator == value.npos)
			? value.substr(value.size())
			: value.substr(separator + 1);

		if (type == "Basic")
		{
			// The credentials decode to "user:password": the user name ends at
			// the first ':' and the password runs to the end of the line.
			std::stringstream decodedStream(
				arrow::util::base64_decode(credentials));
			std::string userName("");
			std::string password("");
			std::getline(decodedStream, userName, ':');
			std::getline(decodedStream, password);
#if ARROW_VERSION_MAJOR >= 13
			const auto& clientAddress = context.peer();
#else
			// The pre-13 signature provides no peer address, so use a
			// placeholder. 192.0.2.1 is in TEST-NET-1 (RFC 5737), an IPv4
			// block reserved for documentation.
			std::string clientAddress("ipv4:192.0.2.1:2929");
#endif
			auto sessionIDResult =
				proxy_->connect(databaseName, userName, password, clientAddress);
			if (!sessionIDResult.ok())
			{
				return arrow::flight::MakeFlightError(
					arrow::flight::FlightStatusCode::Unauthenticated,
					sessionIDResult.status().ToString());
			}
			auto sessionID = *sessionIDResult;
			*middleware = std::make_shared<HeaderAuthServerMiddleware>(sessionID);
			return arrow::Status::OK();
		}
		else if (type == "Bearer")
		{
			// The token must be a non-empty run of decimal digits that fits in
			// uint64_t. Parse by hand because strtoull() would also accept
			// leading whitespace, a sign ("-1" wraps to a huge value), and
			// trailing bytes after an embedded NUL, and it saturates on
			// overflow.
			uint64_t sessionID = 0;
			bool validToken = !credentials.empty();
			for (const char c : credentials)
			{
				if (c < '0' || c > '9')
				{
					validToken = false;
					break;
				}
				const auto digit = static_cast<uint64_t>(c - '0');
				if (sessionID > (UINT64_MAX - digit) / 10)
				{
					// sessionID * 10 + digit would overflow uint64_t.
					validToken = false;
					break;
				}
				sessionID = sessionID * 10 + digit;
			}
			// NOTE(review): Bearer failures report Unauthorized while the
			// other failures report Unauthenticated. This is kept as-is
			// because clients may depend on the status code.
			if (!validToken || !proxy_->is_valid_session(sessionID))
			{
				return arrow::flight::MakeFlightError(
					arrow::flight::FlightStatusCode::Unauthorized,
					std::string("invalid Bearer token"));
			}
			*middleware = std::make_shared<HeaderAuthServerMiddleware>(sessionID);
			return arrow::Status::OK();
		}
		else
		{
			return arrow::flight::MakeFlightError(
				arrow::flight::FlightStatusCode::Unauthenticated,
				std::string("authorization header must use Basic or Bearer: <") + type +
					std::string(">"));
		}
	}