Sources/Generation/GenerationConfig.swift (43 lines of code) (raw):
//
// GenerationConfig.swift
//
//
// Created by Pedro Cuenca on 7/5/23.
//
import Foundation
/// Parameters that control text generation.
///
/// Essentials taken from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/configuration_utils.py
/// Property names are the camelCase counterparts of the snake_case fields in that file.
public struct GenerationConfig {
    // MARK: Length limits

    public var maxLength = 20
    public var maxNewTokens: Int

    // MARK: Decoding strategy (combined into a mode by `generationMode`)

    public var doSample = false
    public var numBeams = 1
    public var numBeamGroups = 1
    public var penaltyAlpha: Double?

    // MARK: Sampling / logits adjustment

    public var temperature = 1.0
    public var topK = 50
    public var topP = 1.0
    public var repetitionPenalty = 1.0

    // MARK: Special token ids (`nil` when not provided)

    public var padTokenId: Int?
    public var bosTokenId: Int?
    public var eosTokenId: Int?

    /// Creates a configuration.
    ///
    /// Every parameter except `maxNewTokens` has a default matching the property's initial value.
    /// - Parameters:
    ///   - padTokenId: Optional padding token id. Defaults to `nil`.
    ///   - bosTokenId: Optional beginning-of-sequence token id. Defaults to `nil`.
    ///   - eosTokenId: Optional end-of-sequence token id. Defaults to `nil`.
    ///   The token-id parameters are trailing and defaulted, so existing call sites keep compiling.
    public init(
        maxLength: Int = 20,
        maxNewTokens: Int,
        doSample: Bool = false,
        numBeams: Int = 1,
        numBeamGroups: Int = 1,
        penaltyAlpha: Double? = nil,
        temperature: Double = 1.0,
        topK: Int = 50,
        topP: Double = 1.0,
        repetitionPenalty: Double = 1.0,
        padTokenId: Int? = nil,
        bosTokenId: Int? = nil,
        eosTokenId: Int? = nil
    ) {
        self.maxLength = maxLength
        self.maxNewTokens = maxNewTokens
        self.doSample = doSample
        self.numBeams = numBeams
        self.numBeamGroups = numBeamGroups
        self.penaltyAlpha = penaltyAlpha
        self.temperature = temperature
        self.topK = topK
        self.topP = topP
        self.repetitionPenalty = repetitionPenalty
        self.padTokenId = padTokenId
        self.bosTokenId = bosTokenId
        self.eosTokenId = eosTokenId
    }
}
public extension GenerationConfig {
    /// The decoding strategy implied by this configuration.
    ///
    /// Modeled on the mode selection in Hugging Face `transformers`:
    /// - `.contrastiveSearch`: one beam, no sampling, `topK > 1` and a positive `penaltyAlpha`.
    /// - `.greedy`: one beam, one group, no sampling.
    /// - `.sample`: one beam, one group, sampling enabled.
    /// - `.beam`: two or more beams, one group, no sampling.
    /// - `.groupBeam`: two or more beams and two or more groups (the sampling flag is not consulted).
    /// Every other combination (for example beams with sampling) yields `.unsupported`.
    var generationMode: GenerationMode {
        // Contrastive search is a specialization of the single-beam, non-sampling case, so it must be
        // tested before the greedy pattern below would match. It only applies to a single beam,
        // mirroring the `num_beams == 1` condition in transformers; beam configs fall through to the switch.
        if numBeams == 1, topK > 1, !doSample, let alpha = penaltyAlpha, alpha > 0 {
            return .contrastiveSearch
        }
        switch (numBeams, numBeamGroups, doSample) {
        case (1, 1, false): return .greedy
        case (1, 1, true): return .sample
        case (2..., 1, false): return .beam
        case (2..., 2..., _): return .groupBeam
        default: return .unsupported
        }
    }
}
// MARK: - Decodable

/// `Decodable` conformance with a compiler-synthesized `init(from:)` (the extension body is empty).
extension GenerationConfig: Decodable {}