Sources/TransformersCLI/main.swift (93 lines of code) (raw):
import ArgumentParser
import CoreML
import Foundation
import Generation
import Models
/// Command-line tool that loads a Core ML language model and generates text for a prompt.
///
/// `run()` compiles the model when needed, loads it, performs one silent warm-up generation and
/// then a second generation whose text is streamed to standard output together with
/// throughput statistics.
@available(iOS 16.2, macOS 13.1, *)
struct TransformersCLI: ParsableCommand {
    static let configuration = CommandConfiguration(
        abstract: "Run text generation on a Core ML language model",
        version: "0.0.1"
    )

    @Argument(help: "Input text")
    var prompt: String

    @Argument(help: "Path to Core ML mlpackage model")
    var modelPath: String = "./model.mlpackage"

    @Option(help: "Maximum amount of tokens the model should generate")
    var maxLength: Int = 50

    @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
    var computeUnits: ComputeUnits = .cpuAndGPU

    /// Runs a single generation pass for `prompt` and blocks the calling thread until it finishes.
    ///
    /// The asynchronous `model.generate` call is started in an unstructured `Task`; a semaphore
    /// bridges its completion back to this synchronous method.
    ///
    /// If generation throws, the error is handed to `TransformersCLI.exit(withError:)`, which does
    /// not return: the process ends instead of continuing with a model that just failed.
    ///
    /// - Parameters:
    ///   - model: The loaded language model to generate with.
    ///   - config: Generation settings forwarded to `model.generate`.
    ///   - prompt: The input text.
    ///   - printOutput: When `true` (the default) newly generated text is written to standard
    ///     output as it arrives, followed by a tokens-per-second summary. When `false` nothing
    ///     is printed on success.
    func generate(model: LanguageModel, config: GenerationConfig, prompt: String, printOutput: Bool = true) {
        let semaphore = DispatchSemaphore(value: 0)
        Task.init { [config] in
            defer { semaphore.signal() }
            var tokensReceived = 0
            // Progress is remembered as a number of Unicode scalars already printed. A
            // `String.Index` saved from an earlier callback belongs to a different string and is
            // not guaranteed to be valid in the current one; subscripting with it can trap if the
            // text handed to a later callback is shorter. `dropFirst` never goes out of bounds.
            var printedScalarCount = 0
            // Monotonic clock: unlike `Date()`, it is not affected by wall-clock adjustments.
            let begin = ProcessInfo.processInfo.systemUptime
            do {
                try await model.generate(config: config, prompt: prompt) { inProgressGeneration in
                    tokensReceived += 1
                    // Turn the two-character sequence backslash + "n" into a real line break.
                    let response = inProgressGeneration.replacingOccurrences(of: "\\n", with: "\n")
                    let scalars = response.unicodeScalars
                    if printOutput {
                        // Only emit the part that has not been printed by a previous callback.
                        print(String(scalars.dropFirst(printedScalarCount)), terminator: "")
                        fflush(stdout)
                    }
                    printedScalarCount = scalars.count
                }
                let completionTime = ProcessInfo.processInfo.systemUptime - begin
                // Guard the division so a zero-length interval cannot yield inf/nan.
                let tps = completionTime > 0 ? Double(tokensReceived) / completionTime : 0
                if printOutput {
                    print("")
                    print("\(tps.formatted("%.2f")) tokens/s, total time: \(completionTime.formatted("%.2f"))s")
                }
            } catch {
                if printOutput, tokensReceived > 0 {
                    // Finish the partially streamed line so the error starts on its own line.
                    print("")
                }
                // Report the failure and stop, rather than printing to stdout and carrying on
                // as if generation had succeeded.
                TransformersCLI.exit(withError: error)
            }
        }
        semaphore.wait()
    }

    /// Returns the URL of a compiled model for `url`, compiling it first when necessary.
    ///
    /// A URL whose path extension is `mlmodelc` is returned unchanged; any other URL is passed
    /// to `MLModel.compileModel(at:)` and the URL it returns is used.
    ///
    /// - Parameter url: Location of the model on disk.
    /// - Returns: The URL of the compiled model.
    /// - Throws: Whatever `MLModel.compileModel(at:)` throws.
    /// - Note: On watchOS this method always traps with `fatalError`.
    func compile(at url: URL) throws -> URL {
        #if os(watchOS)
        fatalError("Model compilation is not supported on watchOS")
        #else
        if url.pathExtension == "mlmodelc" { return url }
        print("Compiling model \(url)")
        return try MLModel.compileModel(at: url)
        #endif
    }

    /// Entry point invoked by ArgumentParser after the arguments have been parsed.
    ///
    /// - Throws: Errors from compiling or loading the model.
    func run() throws {
        let url = URL(filePath: modelPath)
        let compiledURL = try compile(at: url)
        print("Loading model \(compiledURL)")
        let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits)

        // Using greedy generation for now
        var config = model.defaultGenerationConfig
        config.doSample = false
        config.maxNewTokens = maxLength

        // The first pass runs silently; only the second pass prints text and timing.
        print("Warming up...")
        generate(model: model, config: config, prompt: prompt, printOutput: false)

        print("Generating")
        generate(model: model, config: config, prompt: prompt)
    }
}
/// String-backed compute-unit choice used as the type of the command's `computeUnits` option.
///
/// Each case's raw value is its own name, and every case maps to the `MLComputeUnits` member
/// with the same name.
@available(iOS 16.2, macOS 13.1, *)
enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
    case all
    case cpuAndGPU
    case cpuOnly
    case cpuAndNeuralEngine

    /// The Core ML `MLComputeUnits` value corresponding to this case.
    var asMLComputeUnits: MLComputeUnits {
        switch self {
        case .all:
            return .all
        case .cpuAndGPU:
            return .cpuAndGPU
        case .cpuOnly:
            return .cpuOnly
        case .cpuAndNeuralEngine:
            return .cpuAndNeuralEngine
        }
    }
}
// Program entry point: run the command where the required OS version is available.
if #available(iOS 16.2, macOS 13.1, *) {
    TransformersCLI.main()
} else {
    // Report the problem on standard error and exit with a failure status, so scripts and
    // shells do not mistake an unsupported platform for a successful run.
    FileHandle.standardError.write(Data("Unsupported OS\n".utf8))
    exit(EXIT_FAILURE)
}
extension Double {
    /// Renders this value with a `printf`-style format string such as `"%.2f"`.
    ///
    /// - Parameter format: The format string; this value is supplied as its single argument.
    /// - Returns: The formatted text.
    func formatted(_ format: String) -> String {
        let arguments: [CVarArg] = [self]
        return String(format: format, arguments: arguments)
    }
}