Sources/Models/LanguageModel.swift (174 lines of code) (raw):
//
// LanguageModel.swift
//
//
// Created by Pedro Cuenca on 7/5/23.
//
import CoreML
import Generation
import Hub
import Tokenizers
/// A Core ML language model together with the context-length limits derived from
/// the shape constraint of its `input_ids` input.
///
/// The initializer inspects the model description to determine the minimum and
/// maximum sequence lengths, then creates the Hub-backed configuration object
/// from the model's name.
public class LanguageModel {
    /// The underlying Core ML model.
    public let model: MLModel

    /// Smallest sequence length read from the `input_ids` shape constraint.
    public let minContextLength: Int
    /// Largest sequence length read from the `input_ids` shape constraint.
    public let maxContextLength: Int

    /// Name of the input feature that carries the token ids.
    let input_ids = "input_ids"
    /// Name of the (optional) input feature that carries the attention mask.
    let attention_mask = "attention_mask"

    /// Groups the model configuration with the tokenizer configuration and data.
    struct Configurations {
        var modelConfig: Config
        var tokenizerConfig: Config?
        var tokenizerData: Config
    }

    /// Hub-backed configuration; assigned as the last step of `init(model:)`.
    private var configuration: LanguageModelConfigurationFromHub?
    /// Backing storage for a tokenizer instance; `nil` until one is stored.
    private var _tokenizer: Tokenizer?

    /// Wraps `model` and derives the context-length limits from its `input_ids` input.
    ///
    /// Terminates with `fatalError` when the model exposes no multi-array shape
    /// constraint for an input named `input_ids`.
    ///
    /// - Parameter model: The Core ML model to wrap.
    public required init(model: MLModel) {
        self.model = model

        // We assume inputs named "input_ids" with shape (1, seq_length)
        // Perhaps we should convert to vectors of shape (seq_length) and use sequenceConstraint instead of shapeConstraint
        let idsDescription = model.modelDescription.inputDescriptionsByName["input_ids"]
        guard let constraint = idsDescription?.multiArrayConstraint?.shapeConstraint else {
            fatalError("Cannot obtain shape information")
        }

        // Length used whenever the constraint does not provide one.
        let fallbackLength = 128
        let lowerBound: Int
        let upperBound: Int

        switch constraint.type {
        case .enumerated:
            // TODO: support a set of fixed shapes (keeping the first one here)
            lowerBound = constraint.enumeratedShapes[0][1].intValue
            upperBound = lowerBound
        case .range:
            // Dimension 1 is the sequence dimension under the (1, seq_length) assumption.
            // NOTE(review): `length` is taken as the maximum as-is; confirm that Core ML
            // reports the upper bound here rather than a count of allowed sizes.
            let sizeRange = constraint.sizeRangeForDimension[1] as? NSRange
            lowerBound = sizeRange?.location ?? 1
            upperBound = sizeRange?.length ?? fallbackLength
        case .unspecified:
            lowerBound = fallbackLength
            upperBound = fallbackLength
        @unknown default:
            lowerBound = fallbackLength
            upperBound = fallbackLength
        }

        minContextLength = lowerBound
        maxContextLength = upperBound

        // All stored properties are set at this point, so the computed `modelName` may be used.
        configuration = LanguageModelConfigurationFromHub(modelName: modelName)
    }
}
public extension LanguageModel {
    /// Loads a Core ML model from `url` and wraps it in a `LanguageModel`.
    ///
    /// - Parameters:
    ///   - url: Location passed to `MLModel(contentsOf:configuration:)`.
    ///   - computeUnits: Compute units the model is allowed to use. Defaults to `.cpuAndGPU`.
    /// - Returns: A `LanguageModel` backed by the loaded model.
    /// - Throws: Any error thrown while loading the model.
    static func loadCompiled(url: URL, computeUnits: MLComputeUnits = .cpuAndGPU) throws -> LanguageModel {
        let modelConfiguration = MLModelConfiguration()
        modelConfiguration.computeUnits = computeUnits

        let loadedModel = try MLModel(contentsOf: url, configuration: modelConfiguration)
        return LanguageModel(model: loadedModel)
    }
}
public extension LanguageModel {
    /// Human-readable description of the model.
    ///
    /// Returns the description stored in the model metadata when it is a non-empty
    /// string; otherwise the configured display name, or an empty string if there is none.
    var description: String {
        let metadata = model.modelDescription.metadata
        if let text = metadata[MLModelMetadataKey.description] as? String, !text.isEmpty {
            return text
        }
        return model.configuration.modelDisplayName ?? ""
    }

    /// `name_or_path` in the Python world
    ///
    /// Read from the creator-defined metadata key `co.huggingface.exporters.name` when
    /// present; otherwise the model's display name. Terminates with `fatalError` when
    /// neither is available.
    var modelName: String {
        let metadata = model.modelDescription.metadata
        if let userFields = metadata[MLModelMetadataKey.creatorDefinedKey] as? [String: String],
           let name = userFields["co.huggingface.exporters.name"]
        {
            return name
        }
        // This is usually the basename of the file, that's our best bet if no metadata exists
        guard let displayName = model.configuration.modelDisplayName else {
            fatalError("Models must have a name that identifies them")
        }
        return displayName
    }

    /// Description of the `input_ids` input feature.
    ///
    /// The force unwrap cannot fail for instances created through `init(model:)`,
    /// which aborts when the model has no usable `input_ids` input.
    var inputIdsDescription: MLFeatureDescription {
        model.modelDescription.inputDescriptionsByName[input_ids]!
    }

    /// Name of the `input_ids` input feature as reported by the model description.
    var inputIdsName: String {
        inputIdsDescription.name
    }

    /// The shape declared for the `input_ids` input in the model description.
    var inputIdsShape: [Int] {
        inputIdsDescription.multiArrayConstraint!.shape.map(\.intValue)
    }

    /// Whether the model declares an `attention_mask` input.
    var requiresAttention: Bool {
        model.modelDescription.inputDescriptionsByName[attention_mask] != nil
    }

    /// Runs the model on `tokens` and returns the scores found at the position of the
    /// last token that was fed to the model.
    ///
    /// The input is truncated to `maxContextLength` tokens and, when shorter than
    /// `minContextLength`, right-padded with `config.padTokenId` (or `0` when unset).
    /// If the model declares an attention mask input, a mask holding `1` for real
    /// tokens and `0` for padding is supplied as well.
    ///
    /// MLShapedArrayProtocol is either a MLShapedArray or a MLShapedArraySlice
    ///
    /// - Parameters:
    ///   - tokens: The token sequence to score.
    ///   - config: Generation settings; only `padTokenId` is read here.
    /// - Returns: The slice of the output scores at index `[0, maxTokens - 1]`, where
    ///   `maxTokens` is the number of (unpadded) tokens passed to the model.
    /// - Note: Failures while building the input or running the prediction trap (`try!`).
    func predictNextTokenScores(_ tokens: InputTokens, config: GenerationConfig) -> any MLShapedArrayProtocol {
        // TODO: exceptions

        // Maybe pad or truncate
        // NOTE(review): truncation keeps the FIRST `maxContextLength` tokens, so once the
        // sequence outgrows the context the same prefix is scored on every call. Confirm
        // whether keeping the most recent tokens is what callers actually need.
        let maxTokens = min(tokens.count, maxContextLength)
        let padLength = maxTokens >= minContextLength ? 0 : minContextLength - maxTokens
        let inputTokens = Array(tokens[0..<maxTokens]) + Array(repeating: config.padTokenId ?? 0, count: padLength)

        // The declared shape is only usable when it holds exactly as many elements as we
        // are about to provide. For models whose sequence dimension is flexible that is
        // not guaranteed, so fall back to the (1, seq_length) layout assumed throughout.
        let declaredShape = inputIdsShape
        let inputShape = declaredShape.reduce(1, *) == inputTokens.count ? declaredShape : [1, inputTokens.count]

        let inputIds = MLShapedArray<Int32>(scalars: inputTokens.map { Int32($0) }, shape: inputShape)
        var inputDictionary = [inputIdsName: MLFeatureValue(shapedArray: inputIds)]
        if requiresAttention {
            let mask = Array(repeating: 1, count: maxTokens) + Array(repeating: 0, count: padLength)
            let attentionMask = MLShapedArray<Int32>(scalars: mask.map { Int32($0) }, shape: inputShape)
            inputDictionary[attention_mask] = MLFeatureValue(shapedArray: attentionMask)
        }

        let input = try! MLDictionaryFeatureProvider(dictionary: inputDictionary)
        let output = try! model.prediction(from: input)

        // TODO: maybe try to support models with "token_scores" too (after the softmax)
        // `featureNames` is an unordered set, so "the first feature" is arbitrary as soon
        // as the model has more than one output. Ask for the logits by name instead.
        let scoresFeature: MLFeatureValue
        if let logits = output.featureValue(for: "logits") {
            scoresFeature = logits
        } else {
            // Previous behaviour: flag the unexpected output in debug builds and use
            // whichever feature comes first.
            assertionFailure("Expected a model output named \"logits\"")
            scoresFeature = output.featureValue(for: output.featureNames.first!)!
        }

        let scores = scoresFeature.shapedArrayValue(of: Float.self)!
        let nextTokenScores = scores[0, maxTokens - 1]
        return nextTokenScores
    }
}
/// async properties downloaded from the configuration
public extension LanguageModel {
    /// The model configuration exposed by the Hub-backed configuration object.
    var modelConfig: Config {
        get async throws {
            let hubConfiguration = configuration!
            return try await hubConfiguration.modelConfig
        }
    }

    /// The tokenizer configuration, or `nil` when the configuration object has none.
    var tokenizerConfig: Config? {
        get async throws {
            let hubConfiguration = configuration!
            return try await hubConfiguration.tokenizerConfig
        }
    }

    /// The tokenizer data exposed by the Hub-backed configuration object.
    var tokenizerData: Config {
        get async throws {
            let hubConfiguration = configuration!
            return try await hubConfiguration.tokenizerData
        }
    }

    /// The `modelType` entry of the model configuration, as a string.
    var modelType: String? {
        get async throws {
            let config = try await modelConfig
            return config.modelType.string()
        }
    }

    /// The `taskSpecificParams.textGeneration` section of the model configuration.
    var textGenerationParameters: Config? {
        get async throws {
            let config = try await modelConfig
            return config.taskSpecificParams.textGeneration
        }
    }

    /// The `doSample` flag from the text generation parameters; `true` when unavailable.
    var defaultDoSample: Bool {
        get async throws {
            let parameters = try await textGenerationParameters
            return parameters?.doSample.boolean() ?? true
        }
    }

    /// The `bosTokenId` entry of the model configuration, as an integer.
    var bosTokenId: Int? {
        get async throws {
            try await modelConfig.bosTokenId.integer()
        }
    }

    /// The `eosTokenId` entry of the model configuration, as an integer.
    var eosTokenId: Int? {
        get async throws {
            try await modelConfig.eosTokenId.integer()
        }
    }

    /// The tokenizer for this model, created on first access and reused afterwards.
    ///
    /// - Throws: `TokenizerError.tokenizerConfigNotFound` when no tokenizer configuration
    ///   is available, or any error thrown while fetching the configuration or building
    ///   the tokenizer.
    var tokenizer: Tokenizer {
        get async throws {
            // Reuse the instance stored by an earlier access.
            if let cached = _tokenizer {
                return cached
            }

            guard let config = try await tokenizerConfig else {
                throw TokenizerError.tokenizerConfigNotFound
            }
            let data = try await tokenizerData

            let created: Tokenizer = try AutoTokenizer.from(tokenizerConfig: config, tokenizerData: data)
            _tokenizer = created
            return created
        }
    }
}
extension LanguageModel: TextGenerationModel {
    // TODO: retrieve from the json: https://huggingface.co/nlpcloud/instruct-gpt-j-fp16/blob/main/config.json#L26
    /// Generation settings used when the caller does not supply any.
    ///
    /// Starts from a configuration limited to 30 new tokens. When the lowercased model
    /// name contains `gpt`, sampling is enabled with `topK` set to 50.
    public var defaultGenerationConfig: GenerationConfig {
        var generationConfig = GenerationConfig(maxNewTokens: 30)
        if modelName.lowercased().contains("gpt") {
            generationConfig.doSample = true
            generationConfig.topK = 50
        }
        return generationConfig
    }
}
/// Errors raised while preparing a tokenizer.
public enum TokenizerError: LocalizedError {
    /// No tokenizer configuration was available.
    case tokenizerConfigNotFound

    /// Localized, user-facing explanation of the error.
    public var errorDescription: String? {
        switch self {
        case .tokenizerConfigNotFound:
            return String(
                localized: "Tokenizer configuration could not be found. The model may be missing required tokenizer files.",
                comment: "Error when tokenizer configuration is missing"
            )
        }
    }
}