//
// Generation.swift
//
//
// Created by Pedro Cuenca on 7/5/23.
//
import CoreML
import TensorUtils
import Tokenizers

/// Decoding strategies for text generation. Only `.greedy` and `.sample` are
/// currently implemented in `generate` below; the others are placeholders.
public enum GenerationMode {
    case contrastiveSearch
    case greedy
    case sample
    case beam
    case groupBeam
    case unsupported
}
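
// A sketch of how a mode is typically selected through the configuration. The
// `doSample` field name is an assumption about `GenerationConfig`, not defined in
// this file; `maxNewTokens` and `generationMode` are used by `generate` below:
//
//     var config = someGenerationConfig   // however the config is obtained
//     config.doSample = true              // request `.sample` instead of `.greedy`
//     config.maxNewTokens = 50            // generate at most 50 new tokens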

/// Token ids that form the input prompt
public typealias InputTokens = [Int]

/// Token ids produced by generation (the prompt plus newly generated tokens)
public typealias GenerationOutput = [Int]

/// A callable (usually a model) that predicts the next token after a given sequence
public typealias NextTokenModel = (InputTokens, GenerationConfig) -> any MLShapedArrayProtocol

/// Called after each generation step with all tokens produced so far
public typealias PredictionTokensCallback = (GenerationOutput) -> Void

/// Called after each generation step with the decoded text produced so far
public typealias PredictionStringCallback = (String) -> Void
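
// A minimal sketch of a `NextTokenModel`. `runModel(_:)` is a hypothetical helper
// standing in for an actual Core ML prediction call; anything that returns the
// logits for the next position as a shaped array fits the signature:
//
//     let model: NextTokenModel = { tokens, config in
//         runModel(tokens)  // e.g. an MLShapedArray<Float> of vocabulary size
//     }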

// TODO: callbacks (for streaming)
public protocol Generation {
    func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback?) async -> GenerationOutput

    func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback?) async -> String
}
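
// A usage sketch. Both requirements have default implementations in the extension
// below, so an empty conforming type is enough; `config`, `model`, and `tokenizer`
// are assumed to come from elsewhere (e.g. a loaded Core ML model and a tokenizer
// from the `Tokenizers` module):
//
//     struct Generator: Generation {}
//
//     let text = await Generator().generate(
//         config: config, prompt: "Hello", model: model, tokenizer: tokenizer
//     ) { partial in
//         print(partial)  // receives the decoded text after every step
//     }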

public extension Generation {
    func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput {
        // Iterate until we find the eos token or reach the max length
        // TODO: additional stopping criteria
        var outputTokens = tokens
        while outputTokens.count < config.maxLength {
            let logits = model(outputTokens, config)
            let (nextToken, _) = Math.argmax(logits)
            if nextToken == config.eosTokenId { break }
            outputTokens.append(nextToken)
            callback?(outputTokens)
        }
        return outputTokens
    }
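
    // Greedy decoding is deterministic: at every step the single highest-scoring token
    // wins. For logits [0.1, 3.2, 0.5], `Math.argmax` yields (1, 3.2), so token 1 is
    // appended regardless of any sampling settings in the config.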

    /// https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L1552
    func sample(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput {
        // Iterate until we find the eos token or reach the max length
        // TODO: additional stopping criteria
        var outputTokens = tokens
        let logitsProcessor = LogitsProcessor(logitsWarpers: logitsWarpers(config: config))
        while outputTokens.count < config.maxLength {
            let outputs = model(outputTokens, config)
            // `floats` can be much faster than `scalars` for a vector with stride 1, as it uses `memcpy` in that case
            let logits = (outputs as? MLShapedArraySlice<Float>)?.floats ?? outputs.scalars as! [Float]
            let (indexes, processedLogits) = logitsProcessor(logits)
            let nextToken = Math.sample(indexes: indexes, probs: Math.softmax(processedLogits))
            if nextToken == config.eosTokenId { break }
            outputTokens.append(nextToken)
            callback?(outputTokens)
        }
        return outputTokens
    }
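
    // For example, with `topK = 2` and logits [0.1, 2.0, 1.5, -3.0], the processor
    // keeps indexes [1, 2]; `Math.softmax` then normalizes just those two logits, and
    // `Math.sample` draws the next token id from that reduced distribution.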

    func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback? = nil) async -> String {
        let tokens = tokenizer.encode(text: prompt)
        var generationConfig = config
        generationConfig.maxLength = config.maxNewTokens + tokens.count

        let output: GenerationOutput
        switch generationConfig.generationMode {
        case .greedy:
            output = await greedySearch(config: generationConfig, tokens: tokens, model: model) { tokens in
                callback?(tokenizer.decode(tokens: tokens))
            }
        case .sample:
            output = await sample(config: generationConfig, tokens: tokens, model: model) { tokens in
                callback?(tokenizer.decode(tokens: tokens))
            }
        default:
            fatalError("Generation mode \(generationConfig.generationMode) not implemented yet")
        }
        return tokenizer.decode(tokens: output)
    }

    private func logitsWarpers(config: GenerationConfig) -> [any LogitsWarper] {
        var logitsWarpers = [any LogitsWarper]()
        if config.temperature > 0, config.temperature != 1 {
            logitsWarpers.append(TemperatureLogitsWarper(temperature: Float(config.temperature)))
        }
        if config.topK > 0 {
            logitsWarpers.append(TopKLogitsWarper(k: config.topK))
        }
        if config.topP < 1.0 {
            logitsWarpers.append(TopPLogitsWarper(p: Float(config.topP)))
        }
        if config.repetitionPenalty != 1.0 {
            logitsWarpers.append(RepetitionPenaltyWarper(penalty: config.repetitionPenalty))
        }
        return logitsWarpers
    }
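
    // Note: the warpers above run in the order they are appended, following the usual
    // `transformers` convention (an assumption about these warper implementations, not
    // shown in this file). With temperature 0.7, topK 50, and topP 0.95, logits would
    // be scaled by 1/0.7 first, then truncated to the 50 highest entries, and finally
    // cut to the smallest nucleus whose cumulative probability reaches 0.95.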
}