Sources/GoogleAI/GenerativeModel.swift (217 lines of code) (raw):

// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// A type that represents a remote multimodal model (like Gemini), with the ability to generate
/// content based on various input types.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public final class GenerativeModel {
  // The prefix for a model resource in the Gemini API.
  private static let modelResourcePrefix = "models/"

  /// The resource name of the model in the backend; has the format "models/model-name".
  let modelResourceName: String

  /// The backing service responsible for sending and receiving model requests to the backend.
  let generativeAIService: GenerativeAIService

  /// Configuration parameters used for the MultiModalModel.
  let generationConfig: GenerationConfig?

  /// The safety settings to be used for prompts.
  let safetySettings: [SafetySetting]?

  /// A list of tools the model may use to generate the next response.
  let tools: [Tool]?

  /// Tool configuration for any `Tool` specified in the request.
  let toolConfig: ToolConfig?

  /// Instructions that direct the model to behave a certain way.
  let systemInstruction: ModelContent?

  /// Configuration parameters for sending requests to the backend.
  let requestOptions: RequestOptions

  /// Initializes a new remote model with the given parameters.
  ///
  /// - Parameters:
  ///   - name: The name of the model to use, for example `"gemini-1.5-pro-latest"`; see
  ///     [Gemini models](https://ai.google.dev/models/gemini) for a list of supported model names.
  ///   - apiKey: The API key for your project.
  ///   - generationConfig: The content generation parameters your model should use.
  ///   - safetySettings: A value describing what types of harmful content your model should allow.
  ///   - tools: A list of ``Tool`` objects that the model may use to generate the next response.
  ///   - systemInstruction: Instructions that direct the model to behave a certain way; currently
  ///     only text content is supported, for example
  ///     `ModelContent(role: "system", parts: "You are a cat. Your name is Neko.")`.
  ///   - toolConfig: Tool configuration for any `Tool` specified in the request.
  ///   - requestOptions: Configuration parameters for sending requests to the backend.
  public convenience init(name: String,
                          apiKey: String,
                          generationConfig: GenerationConfig? = nil,
                          safetySettings: [SafetySetting]? = nil,
                          tools: [Tool]? = nil,
                          toolConfig: ToolConfig? = nil,
                          systemInstruction: ModelContent? = nil,
                          requestOptions: RequestOptions = RequestOptions()) {
    self.init(
      name: name,
      apiKey: apiKey,
      generationConfig: generationConfig,
      safetySettings: safetySettings,
      tools: tools,
      toolConfig: toolConfig,
      systemInstruction: systemInstruction,
      requestOptions: requestOptions,
      urlSession: .shared
    )
  }

  /// Initializes a new remote model with the given parameters.
  ///
  /// - Parameters:
  ///   - name: The name of the model to use, e.g., `"gemini-1.5-pro-latest"`; see
  ///     [Gemini models](https://ai.google.dev/models/gemini) for a list of supported model names.
  ///   - apiKey: The API key for your project.
  ///   - generationConfig: The content generation parameters your model should use.
  ///   - safetySettings: A value describing what types of harmful content your model should allow.
  ///   - tools: A list of ``Tool`` objects that the model may use to generate the next response.
  ///   - systemInstruction: Instructions that direct the model to behave a certain way; currently
  ///     only text content is supported, e.g., "You are a cat. Your name is Neko."
  ///   - toolConfig: Tool configuration for any `Tool` specified in the request.
  ///   - requestOptions: Configuration parameters for sending requests to the backend.
  public convenience init(name: String,
                          apiKey: String,
                          generationConfig: GenerationConfig? = nil,
                          safetySettings: [SafetySetting]? = nil,
                          tools: [Tool]? = nil,
                          toolConfig: ToolConfig? = nil,
                          systemInstruction: String...,
                          requestOptions: RequestOptions = RequestOptions()) {
    self.init(
      name: name,
      apiKey: apiKey,
      generationConfig: generationConfig,
      safetySettings: safetySettings,
      tools: tools,
      toolConfig: toolConfig,
      // Each variadic string becomes one text part of a single "system" content.
      systemInstruction: ModelContent(
        role: "system",
        parts: systemInstruction.map { ModelContent.Part.text($0) }
      ),
      requestOptions: requestOptions,
      urlSession: .shared
    )
  }

  /// The designated initializer for this class.
  init(name: String,
       apiKey: String,
       generationConfig: GenerationConfig? = nil,
       safetySettings: [SafetySetting]? = nil,
       tools: [Tool]? = nil,
       toolConfig: ToolConfig? = nil,
       systemInstruction: ModelContent? = nil,
       requestOptions: RequestOptions = RequestOptions(),
       urlSession: URLSession) {
    modelResourceName = GenerativeModel.modelResourceName(name: name)
    generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
    self.generationConfig = generationConfig
    self.safetySettings = safetySettings
    self.tools = tools
    self.toolConfig = toolConfig
    self.systemInstruction = systemInstruction
    self.requestOptions = requestOptions

    Logging.default.info("""
    [GoogleGenerativeAI] Model \(
      name,
      privacy: .public
    ) initialized. To enable additional logging, add \
    `\(Logging.enableArgumentKey, privacy: .public)` as a launch argument in Xcode.
    """)
    Logging.verbose.debug("[GoogleGenerativeAI] Verbose logging enabled.")
  }

  /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  /// representable as one or more ``ModelContent/Part``s.
  ///
  /// Since ``ModelContent/Part``s do not specify a role, this method is intended for generating
  /// content from
  /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  /// or "direct" prompts. For
  /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  /// prompts, see `generateContent(_ content: @autoclosure () throws -> [ModelContent])`.
  ///
  /// - Parameter parts: The input(s) given to the model as a prompt (see
  ///   ``ThrowingPartsRepresentable`` for conforming types).
  /// - Returns: The content generated by the model.
  /// - Throws: A ``GenerateContentError`` if the request failed.
  public func generateContent(_ parts: any ThrowingPartsRepresentable...) async throws
    -> GenerateContentResponse {
    return try await generateContent([ModelContent(parts: parts)])
  }

  /// Generates new content from input content given to the model as a prompt.
  ///
  /// - Parameter content: The input(s) given to the model as a prompt.
  /// - Returns: The generated content response from the model.
  /// - Throws: A ``GenerateContentError`` if the request failed.
  public func generateContent(_ content: @autoclosure () throws -> [ModelContent]) async throws
    -> GenerateContentResponse {
    let response: GenerateContentResponse
    do {
      let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
                                                              contents: content(),
                                                              generationConfig: generationConfig,
                                                              safetySettings: safetySettings,
                                                              tools: tools,
                                                              toolConfig: toolConfig,
                                                              systemInstruction: systemInstruction,
                                                              isStreaming: false,
                                                              options: requestOptions)
      response = try await generativeAIService.loadRequest(request: generateContentRequest)
    } catch {
      // Image conversion failures get a dedicated error case; everything else is normalized
      // into a `GenerateContentError` for public consumption.
      if let imageError = error as? ImageConversionError {
        throw GenerateContentError.promptImageContentError(underlying: imageError)
      }
      throw GenerativeModel.generateContentError(from: error)
    }

    // Check the prompt feedback to see if the prompt was blocked.
    if response.promptFeedback?.blockReason != nil {
      throw GenerateContentError.promptBlocked(response: response)
    }

    // Check to see if an error should be thrown for stop reason.
    if let reason = response.candidates.first?.finishReason, reason != .stop {
      throw GenerateContentError.responseStoppedEarly(reason: reason, response: response)
    }

    return response
  }

  /// Generates content from String and/or image inputs, given to the model as a prompt, that are
  /// representable as one or more ``ModelContent/Part``s.
  ///
  /// Since ``ModelContent/Part``s do not specify a role, this method is intended for generating
  /// content from
  /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  /// or "direct" prompts. For
  /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  /// prompts, see `generateContent(_ content: @autoclosure () throws -> [ModelContent])`.
  ///
  /// - Parameter parts: The input(s) given to the model as a prompt (see
  ///   ``ThrowingPartsRepresentable`` for conforming types).
  /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  ///   error if an error occurred.
  @available(macOS 12.0, *)
  public func generateContentStream(_ parts: any ThrowingPartsRepresentable...)
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    // `try` marks the throwing `ModelContent(parts:)` call; the error is captured by the
    // `@autoclosure` parameter and surfaced through the returned stream, not thrown here.
    return try generateContentStream([ModelContent(parts: parts)])
  }

  /// Generates new content from input content given to the model as a prompt.
  ///
  /// - Parameter content: The input(s) given to the model as a prompt.
  /// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
  ///   error if an error occurred.
  @available(macOS 12.0, *)
  public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent])
    -> AsyncThrowingStream<GenerateContentResponse, Error> {
    let evaluatedContent: [ModelContent]
    do {
      evaluatedContent = try content()
    } catch let underlying {
      // The content closure failed; return a stream that immediately finishes with the error
      // so callers always receive failures through the stream.
      return AsyncThrowingStream { continuation in
        let error: Error
        if let contentError = underlying as? ImageConversionError {
          error = GenerateContentError.promptImageContentError(underlying: contentError)
        } else {
          error = GenerateContentError.internalError(underlying: underlying)
        }
        continuation.finish(throwing: error)
      }
    }

    let generateContentRequest = GenerateContentRequest(model: modelResourceName,
                                                        contents: evaluatedContent,
                                                        generationConfig: generationConfig,
                                                        safetySettings: safetySettings,
                                                        tools: tools,
                                                        toolConfig: toolConfig,
                                                        systemInstruction: systemInstruction,
                                                        isStreaming: true,
                                                        options: requestOptions)

    var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
      .makeAsyncIterator()
    return AsyncThrowingStream {
      let response: GenerateContentResponse?
      do {
        response = try await responseIterator.next()
      } catch {
        throw GenerativeModel.generateContentError(from: error)
      }

      // The responseIterator will return `nil` when it's done.
      guard let response = response else {
        // This is the end of the stream! Signal it by sending `nil`.
        return nil
      }

      // Check the prompt feedback to see if the prompt was blocked.
      if response.promptFeedback?.blockReason != nil {
        throw GenerateContentError.promptBlocked(response: response)
      }

      // If the stream ended early unexpectedly, throw an error.
      if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
        throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
      } else {
        // Response was valid content, pass it along and continue.
        return response
      }
    }
  }

  /// Creates a new chat conversation using this model with the provided history.
  public func startChat(history: [ModelContent] = []) -> Chat {
    return Chat(model: self, history: history)
  }

  /// Runs the model's tokenizer on String and/or image inputs that are representable as one or
  /// more ``ModelContent/Part``s.
  ///
  /// Since ``ModelContent/Part``s do not specify a role, this method is intended for tokenizing
  /// [zero-shot](https://developers.google.com/machine-learning/glossary/generative#zero-shot-prompting)
  /// or "direct" prompts. For
  /// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
  /// input, see `countTokens(_ content: @autoclosure () throws -> [ModelContent])`.
  ///
  /// - Parameter parts: The input(s) given to the model as a prompt (see
  ///   ``ThrowingPartsRepresentable`` for conforming types).
  /// - Returns: The results of running the model's tokenizer on the input; contains
  ///   ``CountTokensResponse/totalTokens``.
  /// - Throws: A ``CountTokensError`` if the tokenization request failed.
  public func countTokens(_ parts: any ThrowingPartsRepresentable...) async throws
    -> CountTokensResponse {
    return try await countTokens([ModelContent(parts: parts)])
  }

  /// Runs the model's tokenizer on the input content and returns the token count.
  ///
  /// - Parameter content: The input given to the model as a prompt.
  /// - Returns: The results of running the model's tokenizer on the input; contains
  ///   ``CountTokensResponse/totalTokens``.
  /// - Throws: A ``CountTokensError`` if the tokenization request failed or the input content was
  ///   invalid.
  public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
    -> CountTokensResponse {
    do {
      // The token count is computed against the same fully-configured request that
      // `generateContent` would send, so the count matches actual generation.
      let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
                                                              contents: content(),
                                                              generationConfig: generationConfig,
                                                              safetySettings: safetySettings,
                                                              tools: tools,
                                                              toolConfig: toolConfig,
                                                              systemInstruction: systemInstruction,
                                                              isStreaming: false,
                                                              options: requestOptions)
      let countTokensRequest = CountTokensRequest(
        model: modelResourceName,
        generateContentRequest: generateContentRequest,
        options: requestOptions
      )
      return try await generativeAIService.loadRequest(request: countTokensRequest)
    } catch {
      throw CountTokensError.internalError(underlying: error)
    }
  }

  /// Returns a model resource name of the form "models/model-name" based on `name`.
  private static func modelResourceName(name: String) -> String {
    if name.contains("/") {
      // Already a fully-qualified resource name; use it as-is.
      return name
    } else {
      return modelResourcePrefix + name
    }
  }

  /// Returns a `GenerateContentError` (for public consumption) from an internal error.
  ///
  /// If `error` is already a `GenerateContentError` the error is returned unchanged.
  private static func generateContentError(from error: Error) -> GenerateContentError {
    if let error = error as? GenerateContentError {
      return error
    } else if let error = error as? RPCError, error.isInvalidAPIKeyError() {
      return GenerateContentError.invalidAPIKey(message: error.message)
    } else if let error = error as? RPCError, error.isUnsupportedUserLocationError() {
      return GenerateContentError.unsupportedUserLocation
    }
    return GenerateContentError.internalError(underlying: error)
  }
}

/// An error thrown in `GenerativeModel.countTokens(_:)`.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public enum CountTokensError: Error {
  case internalError(underlying: Error)
}