Sources/GoogleAI/GenerativeAIService.swift (192 lines of code) (raw):
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
/// Internal transport layer: turns a `GenerativeAIRequest` into an HTTP `POST`, sends it through
/// the injected `URLSession`, and decodes either a single JSON response or a stream of
/// server-sent events (SSE).
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
struct GenerativeAIService {
  /// Gives permission to talk to the backend.
  private let apiKey: String

  /// The session used to perform every network call.
  private let urlSession: URLSession

  init(apiKey: String, urlSession: URLSession) {
    self.apiKey = apiKey
    self.urlSession = urlSession
  }

  /// Sends `request` and decodes the complete response body.
  ///
  /// - Parameter request: The request to encode as the JSON body of the `POST`.
  /// - Returns: The decoded `T.Response`.
  /// - Throws: Encoding or transport errors, an error if the response is not an HTTP response,
  ///   the error produced by `parseError(responseData:)` when the status code is not 200, or a
  ///   decoding error if the success payload cannot be decoded.
  func loadRequest<T: GenerativeAIRequest>(request: T) async throws -> T.Response {
    let urlRequest = try self.urlRequest(request: request)

    #if DEBUG
      printCURLCommand(from: urlRequest)
    #endif

    let (data, rawResponse) = try await urlSession.data(for: urlRequest)
    let response = try httpResponse(urlResponse: rawResponse)

    // Verify the status code is 200
    guard response.statusCode == 200 else {
      Logging.network.error("[GoogleGenerativeAI] The server responded with an error: \(response)")
      if let responseString = String(data: data, encoding: .utf8) {
        Logging.default.error("[GoogleGenerativeAI] Response payload: \(responseString)")
      }
      throw parseError(responseData: data)
    }

    return try parseResponse(T.Response.self, from: data)
  }

  /// Sends `request` and yields one decoded `T.Response` per `data:` line received.
  ///
  /// Every failure (request encoding, transport, non-200 status, malformed payload) terminates
  /// the returned stream with an error. If the consumer stops iterating, the underlying task is
  /// cancelled.
  ///
  /// - Parameter request: The request to encode as the JSON body of the `POST`.
  /// - Returns: A stream of decoded responses that finishes when the server closes the
  ///   connection.
  @available(macOS 12.0, *)
  func loadRequestStream<T: GenerativeAIRequest>(request: T)
    -> AsyncThrowingStream<T.Response, Error> {
    return AsyncThrowingStream { continuation in
      let task = Task {
        // A single do/catch covers the whole pipeline so that *any* thrown error, including one
        // raised while iterating `stream.lines`, finishes the continuation. Without it the error
        // would be stored in the discarded Task result and the consumer would wait forever.
        do {
          let urlRequest = try self.urlRequest(request: request)

          #if DEBUG
            printCURLCommand(from: urlRequest)
          #endif

          let (stream, rawResponse) = try await urlSession.bytes(for: urlRequest)

          // Make sure we are actually dealing with an HTTP response.
          let response = try httpResponse(urlResponse: rawResponse)

          // Verify the status code is 200
          guard response.statusCode == 200 else {
            Logging.network
              .error("[GoogleGenerativeAI] The server responded with an error: \(response)")
            // Drain the body so the error payload can be logged and parsed.
            var responseBody = ""
            for try await line in stream.lines {
              responseBody += line + "\n"
            }

            Logging.default.error("[GoogleGenerativeAI] Response payload: \(responseBody)")
            continuation.finish(throwing: parseError(responseBody: responseBody))
            return
          }

          // Received lines that are not server-sent events (SSE); these are not prefixed with
          // "data:". They are collected and interpreted as an error payload once the stream ends.
          var extraLines = ""

          for try await line in stream.lines {
            Logging.network.debug("[GoogleGenerativeAI] Stream response: \(line)")

            if line.hasPrefix("data:") {
              // Strip the 5-character "data:" prefix; the remainder is the JSON payload.
              let jsonText = String(line.dropFirst(5))
              let data = try jsonData(jsonText: jsonText)

              // Handle the content.
              let content = try parseResponse(T.Response.self, from: data)
              continuation.yield(content)
            } else {
              extraLines += line
            }
          }

          if !extraLines.isEmpty {
            continuation.finish(throwing: parseError(responseBody: extraLines))
            return
          }

          continuation.finish()
        } catch {
          continuation.finish(throwing: error)
        }
      }

      // Stop the network work when the consumer goes away (or the stream has finished).
      continuation.onTermination = { _ in task.cancel() }
    }
  }

  // MARK: - Private Helpers

  /// Builds the `POST` request for `request`: API key and client headers, JSON content type,
  /// snake_case-encoded JSON body, and the timeout taken from `request.options`.
  ///
  /// - Throws: An encoding error if `request` cannot be serialized.
  private func urlRequest<T: GenerativeAIRequest>(request: T) throws -> URLRequest {
    var urlRequest = URLRequest(url: request.url)
    urlRequest.httpMethod = "POST"
    urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")
    urlRequest.setValue("genai-swift/\(GenerativeAISwift.version)",
                        forHTTPHeaderField: "x-goog-api-client")
    urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")

    let encoder = JSONEncoder()
    encoder.keyEncodingStrategy = .convertToSnakeCase
    urlRequest.httpBody = try encoder.encode(request)
    urlRequest.timeoutInterval = request.options.timeout

    return urlRequest
  }

  /// Casts `urlResponse` to `HTTPURLResponse`.
  ///
  /// - Throws: An `NSError` in the `com.google.generative-ai` domain if the response is not an
  ///   HTTP response.
  private func httpResponse(urlResponse: URLResponse) throws -> HTTPURLResponse {
    // Verify that this is an HTTP response; the status code is checked by the caller.
    guard let response = urlResponse as? HTTPURLResponse else {
      Logging.default
        .error(
          "[GoogleGenerativeAI] Response wasn't an HTTP response, internal error \(urlResponse)"
        )
      throw NSError(
        domain: "com.google.generative-ai",
        code: -1,
        userInfo: [NSLocalizedDescriptionKey: "Response was not an HTTP response."]
      )
    }

    return response
  }

  /// Converts `jsonText` to UTF-8 encoded `Data`.
  ///
  /// - Throws: An `NSError` in the `com.google.generative-ai` domain if the conversion fails.
  private func jsonData(jsonText: String) throws -> Data {
    guard let data = jsonText.data(using: .utf8) else {
      throw NSError(
        domain: "com.google.generative-ai",
        code: -1,
        userInfo: [NSLocalizedDescriptionKey: "Could not parse response as UTF8."]
      )
    }

    return data
  }

  /// Interprets a textual error payload; see `parseError(responseData:)`.
  ///
  /// - Returns: The parsed error, or the conversion error if `responseBody` cannot be turned
  ///   into `Data`.
  private func parseError(responseBody: String) -> Error {
    do {
      let data = try jsonData(jsonText: responseBody)
      return parseError(responseData: data)
    } catch {
      return error
    }
  }

  /// Decodes `responseData` as an `RPCError`.
  ///
  /// - Returns: The decoded `RPCError`, or the decoding error itself if the payload does not
  ///   match.
  private func parseError(responseData: Data) -> Error {
    do {
      return try JSONDecoder().decode(RPCError.self, from: responseData)
    } catch {
      // TODO: Return an error about an unrecognized error payload with the response body
      return error
    }
  }

  /// Decodes `data` as `type`, logging the raw JSON and the decoding error on failure.
  ///
  /// - Throws: The decoding error, unchanged.
  private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
    do {
      return try JSONDecoder().decode(type, from: data)
    } catch {
      if let json = String(data: data, encoding: .utf8) {
        Logging.network.error("[GoogleGenerativeAI] JSON response: \(json)")
      }
      Logging.default.error("[GoogleGenerativeAI] Error decoding server JSON: \(error)")
      throw error
    }
  }

  #if DEBUG
    /// Renders `request` as an equivalent `curl` invocation (headers, URL, JSON body).
    ///
    /// - Returns: The command, or an empty string if the request has no URL or no UTF-8 body.
    private func cURLCommand(from request: URLRequest) -> String {
      var returnValue = "curl "
      if let allHeaders = request.allHTTPHeaderFields {
        for (key, value) in allHeaders {
          returnValue += "-H '\(key): \(value)' "
        }
      }

      guard let url = request.url else { return "" }
      returnValue += "'\(url.absoluteString)' "

      guard let body = request.httpBody,
            let jsonStr = String(bytes: body, encoding: .utf8) else { return "" }

      // Close the quote, emit an escaped single quote, and reopen: ' -> '\''
      let escapedJSON = jsonStr.replacingOccurrences(of: "'", with: "'\\''")
      returnValue += "-d '\(escapedJSON)'"

      return returnValue
    }

    /// Logs the cURL equivalent of `request` at debug level; the command is marked private
    /// because it contains the API key header.
    private func printCURLCommand(from request: URLRequest) {
      let command = cURLCommand(from: request)
      Logging.verbose.debug("""
      [GoogleGenerativeAI] Creating request with the equivalent cURL command:
      ----- cURL command -----
      \(command, privacy: .private)
      ------------------------
      """)
    }
  #endif // DEBUG
}