Sources/Models/LanguageModelTypes.swift (26 lines of code) (raw):

//
//  LanguageModelTypes.swift
//
//
//  Created by Pedro Cuenca on 8/5/23.
//

import CoreML
import Generation
import Tokenizers

// MARK: - LanguageModelProtocol

/// Requirements for a language model backed by an `MLModel` that can score
/// the next token for a sequence of input tokens.
public protocol LanguageModelProtocol {
    /// `name_or_path` in the Python world
    var modelName: String { get }

    /// The tokenizer for this model. Reading it is asynchronous and may throw.
    var tokenizer: Tokenizer { get async throws }

    /// The Core ML model held by this instance.
    var model: MLModel { get }

    /// Creates an instance from a Core ML model.
    init(model: MLModel)

    /// Produces next-token scores for `tokens` using `config`.
    ///
    /// - Parameters:
    ///   - tokens: The input tokens to score.
    ///   - config: The generation configuration to apply.
    /// - Returns: The scores as a shaped array; the concrete type and shape
    ///   are chosen by the conforming type.
    func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol

    /// Makes prediction callable (this works like `__call__` in Python).
    ///
    /// - Parameters:
    ///   - tokens: The input tokens to score.
    ///   - config: The generation configuration to apply.
    /// - Returns: The scores as a shaped array.
    func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol
}

// MARK: - LanguageModelProtocol defaults

public extension LanguageModelProtocol {
    /// Default implementation: forwards to `predictNextTokenScores(_:config:)`.
    func callAsFunction(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol {
        return predictNextTokenScores(tokens, config: config)
    }
}

// MARK: - TextGenerationModel

/// A language model that can also generate text from a prompt.
public protocol TextGenerationModel: Generation, LanguageModelProtocol {
    /// The generation configuration this model exposes as its default.
    var defaultGenerationConfig: GenerationConfig { get }

    /// Generates text for `prompt`.
    ///
    /// - Parameters:
    ///   - config: The generation configuration to apply.
    ///   - prompt: The text prompt to generate from.
    ///   - callback: Optional callback of type `PredictionStringCallback`.
    /// - Returns: The generated string.
    /// - Throws: Any error raised during generation.
    func generate(config: GenerationConfig, prompt: String, callback: PredictionStringCallback?) async throws -> String
}

// MARK: - TextGenerationModel defaults

public extension TextGenerationModel {
    /// Default implementation: forwards to
    /// `generate(config:prompt:model:tokenizer:callback:)` (presumably provided
    /// by `Generation`), passing this instance's `callAsFunction` as the model
    /// and its `tokenizer`.
    ///
    /// - Parameters:
    ///   - config: The generation configuration to apply.
    ///   - prompt: The text prompt to generate from.
    ///   - callback: Optional callback; defaults to `nil`.
    /// - Returns: The generated string. The result may be ignored.
    /// - Throws: Any error raised while reading `tokenizer` or by the
    ///   forwarded `generate` call.
    @discardableResult
    func generate(
        config: GenerationConfig,
        prompt: String,
        callback: PredictionStringCallback? = nil
    ) async throws -> String {
        // Resolve the async, throwing tokenizer before forwarding the call.
        let resolvedTokenizer = try await self.tokenizer
        return try await generate(
            config: config,
            prompt: prompt,
            model: callAsFunction,
            tokenizer: resolvedTokenizer,
            callback: callback
        )
    }
}