Sources/Hub/Hub.swift (213 lines of code) (raw):
//
// Hub.swift
//
//
// Created by Pedro Cuenca on 18/5/23.
//
import Foundation
/// Namespace type for Hub-related declarations; it has no members of its own here and is populated through extensions.
public struct Hub { }
public extension Hub {
/// Errors reported by the Hub client.
///
/// Conforms to `LocalizedError`; `errorDescription` supplies a localized, human-readable
/// message for every case.
enum HubClientError: LocalizedError {
    /// A valid Hugging Face token is needed.
    case authorizationRequired
    /// An HTTP failure; the payload is the status code.
    case httpStatusCode(Int)
    /// The server response could not be parsed.
    case parse
    /// A failure that does not fit any other case.
    case unexpectedError
    /// A download failed; the payload is a descriptive message.
    case downloadError(String)
    /// A file was not found; the payload is its name.
    case fileNotFound(String)
    /// A networking failure, wrapping the underlying `URLError`.
    case networkError(URLError)
    /// A resource was not found; the payload identifies the resource.
    case resourceNotFound(String)
    /// A required configuration file is missing; the payload is the file name.
    case configurationMissing(String)
    /// A file-system failure, wrapping the underlying error.
    case fileSystemError(Error)
    /// A parsing failure; the payload is a descriptive message.
    case parseError(String)

    /// Localized description of the error, including the associated payload where there is one.
    public var errorDescription: String? {
        switch self {
        case .authorizationRequired:
            return String(localized: "Authentication required. Please provide a valid Hugging Face token.")
        case .httpStatusCode(let statusCode):
            return String(localized: "HTTP error with status code: \(statusCode)")
        case .parse:
            return String(localized: "Failed to parse server response.")
        case .unexpectedError:
            return String(localized: "An unexpected error occurred.")
        case .downloadError(let details):
            return String(localized: "Download failed: \(details)")
        case .fileNotFound(let name):
            return String(localized: "File not found: \(name)")
        case .networkError(let underlying):
            return String(localized: "Network error: \(underlying.localizedDescription)")
        case .resourceNotFound(let name):
            return String(localized: "Resource not found: \(name)")
        case .configurationMissing(let name):
            return String(localized: "Required configuration file missing: \(name)")
        case .fileSystemError(let underlying):
            return String(localized: "File system error: \(underlying.localizedDescription)")
        case .parseError(let details):
            return String(localized: "Parse error: \(details)")
        }
    }
}
/// The kinds of repository that can be referenced.
///
/// Each case's raw value is its own name (`"models"`, `"datasets"`, `"spaces"`), which is
/// also the value used when encoding and decoding.
enum RepoType: String, Codable {
    case models, datasets, spaces
}
/// A reference to a repository, made of an identifier and a repository kind.
struct Repo: Codable {
/// The repository identifier.
public let id: String
/// The kind of repository this identifier refers to.
public let type: RepoType
/// Creates a repository reference.
/// - Parameters:
///   - id: The repository identifier.
///   - type: The repository kind; defaults to `.models`.
public init(id: String, type: RepoType = .models) {
self.id = id
self.type = type
}
}
}
public class LanguageModelConfigurationFromHub {
/// The set of configuration objects produced by a `loadConfig` call.
struct Configurations {
/// Contents of `config.json`.
var modelConfig: Config
/// Contents of `tokenizer_config.json`, when available; may also carry a merged `chat_template` entry.
var tokenizerConfig: Config?
/// Contents of `tokenizer.json`.
var tokenizerData: Config
}
/// Task that loads the configurations; assigned by both initializers and force-unwrapped by the async accessors.
private var configPromise: Task<Configurations, Error>?
/// Creates a configuration loader for a model identified by name.
///
/// The initializer returns immediately: loading is started in an unstructured `Task` and the
/// results are read later through the async properties of this object.
///
/// - Parameters:
///   - modelName: Repository identifier of the model.
///   - revision: Revision to load; defaults to `"main"`.
///   - hubApi: API object used to fetch and read the files; defaults to `.shared`.
public init(
    modelName: String,
    revision: String = "main",
    hubApi: HubApi = .shared
) {
    configPromise = Task {
        try await self.loadConfig(modelName: modelName, revision: revision, hubApi: hubApi)
    }
}
/// Creates a configuration loader that reads from a local model folder.
///
/// The initializer returns immediately: loading is started in an unstructured `Task` and the
/// results are read later through the async properties of this object.
///
/// - Parameters:
///   - modelFolder: Directory that contains the configuration files.
///   - hubApi: API object used to read the files; defaults to `.shared`.
public init(
    modelFolder: URL,
    hubApi: HubApi = .shared
) {
    let loadingTask = Task {
        try await self.loadConfig(modelFolder: modelFolder, hubApi: hubApi)
    }
    configPromise = loadingTask
}
/// The model configuration.
///
/// Awaits the loading task and rethrows its error if loading failed. Traps if the loading task
/// has not been assigned.
public var modelConfig: Config {
    get async throws {
        let configurations = try await configPromise!.value
        return configurations.modelConfig
    }
}
/// The tokenizer configuration, completed with a tokenizer class when the loaded one lacks it.
///
/// Resolution order:
/// 1. No tokenizer config was loaded: return the bundled fallback for the model type, or `nil`.
/// 2. The loaded config names a tokenizer class: return it unchanged.
/// 3. The model type is unknown: return the loaded config unchanged.
/// 4. A bundled fallback exists for the model type: return the fallback merged with the loaded config.
/// 5. Otherwise: return the loaded config with `tokenizer_class` set to the capitalized model
///    type followed by `Tokenizer`.
///
/// Awaits the loading task and rethrows its error if loading failed.
public var tokenizerConfig: Config? {
    get async throws {
        guard let hubConfig = try await configPromise!.value.tokenizerConfig else {
            // Nothing was loaded: the bundled fallback config is the only option, if one exists.
            guard let modelType = try await modelType else { return nil }
            return Self.fallbackTokenizerConfig(for: modelType)
        }

        // The loaded config already declares its tokenizer class, so no guessing is needed.
        if let _: String = hubConfig.tokenizerClass?.string() { return hubConfig }

        // The class is missing; it can only be guessed when the model type is known.
        guard let modelType = try await modelType else { return hubConfig }

        // Prefer a bundled fallback config for this model type when we have one.
        // NOTE(review): on key conflicts the strategy keeps `current`, presumably the fallback's
        // value (the receiver) — confirm against `merging(_:strategy:)`.
        if let fallbackConfig = Self.fallbackTokenizerConfig(for: modelType) {
            let merged = fallbackConfig.dictionary()?.merging(hubConfig.dictionary(or: [:]), strategy: { current, _ in current }) ?? [:]
            return Config(merged)
        }

        // Last resort: derive the class name by capitalizing the model type.
        var guessed = hubConfig.dictionary(or: [:])
        guessed["tokenizer_class"] = .init("\(modelType.capitalized)Tokenizer")
        return Config(guessed)
    }
}
/// The tokenizer data.
///
/// Awaits the loading task and rethrows its error if loading failed. Traps if the loading task
/// has not been assigned.
public var tokenizerData: Config {
    get async throws {
        let configurations = try await configPromise!.value
        return configurations.tokenizerData
    }
}
/// The string value of the model configuration's `modelType` entry, or `nil` when that lookup
/// yields no string.
///
/// Awaits `modelConfig` and rethrows its error.
public var modelType: String? {
    get async throws {
        let config = try await modelConfig
        return config.modelType.string()
    }
}
/// Fetches a snapshot of a model's configuration files and loads them from the snapshot folder.
///
/// - Parameters:
///   - modelName: Repository identifier of the model.
///   - revision: Revision to fetch.
///   - hubApi: API object used for the snapshot and for reading the files; defaults to `.shared`.
/// - Returns: The configurations loaded from the snapshot folder.
/// - Throws: `Hub.HubClientError.resourceNotFound` for a `URLError` with code
///   `.resourceUnavailable`, `Hub.HubClientError.networkError` for any other `URLError`, and any
///   other error unchanged.
func loadConfig(
    modelName: String,
    revision: String,
    hubApi: HubApi = .shared
) async throws -> Configurations {
    let requestedFiles = ["config.json", "tokenizer_config.json", "chat_template.jinja", "chat_template.json", "tokenizer.json"]
    let repo = Hub.Repo(id: modelName)
    do {
        let snapshotFolder = try await hubApi.snapshot(from: repo, revision: revision, matching: requestedFiles)
        return try await loadConfig(modelFolder: snapshotFolder, hubApi: hubApi)
    } catch let urlError as URLError {
        // Translate URL loading failures into Hub errors; every other error propagates untouched.
        guard urlError.code == .resourceUnavailable else {
            throw Hub.HubClientError.networkError(urlError)
        }
        throw Hub.HubClientError.resourceNotFound(modelName)
    }
}
/// Loads the model and tokenizer configuration files found in a local folder.
///
/// `config.json` and `tokenizer.json` are required; `tokenizer_config.json` is optional. When a
/// chat template file is present (`chat_template.jinja` is preferred over `chat_template.json`),
/// the template is stored under the `chat_template` key of the tokenizer configuration, creating
/// that configuration if it does not exist yet.
///
/// - Parameters:
///   - modelFolder: Directory that contains the configuration files.
///   - hubApi: API object used to read each JSON file via `configuration(fileURL:)`; defaults to `.shared`.
/// - Returns: The loaded configurations.
/// - Throws: `Hub.HubClientError.configurationMissing` when a required file is absent,
///   `Hub.HubClientError.parseError` when the underlying failure is a JSON parsing error, and
///   `Hub.HubClientError.fileSystemError` for any other failure. A `Hub.HubClientError` raised
///   while loading is rethrown unchanged.
func loadConfig(
    modelFolder: URL,
    hubApi: HubApi = .shared
) async throws -> Configurations {
    do {
        // Load required configurations
        let modelConfigURL = modelFolder.appending(path: "config.json")
        guard FileManager.default.fileExists(atPath: modelConfigURL.path) else {
            throw Hub.HubClientError.configurationMissing("config.json")
        }
        let modelConfig = try hubApi.configuration(fileURL: modelConfigURL)

        let tokenizerDataURL = modelFolder.appending(path: "tokenizer.json")
        guard FileManager.default.fileExists(atPath: tokenizerDataURL.path) else {
            throw Hub.HubClientError.configurationMissing("tokenizer.json")
        }
        let tokenizerData = try hubApi.configuration(fileURL: tokenizerDataURL)

        // Load tokenizer config (optional)
        var tokenizerConfig: Config? = nil
        let tokenizerConfigURL = modelFolder.appending(path: "tokenizer_config.json")
        if FileManager.default.fileExists(atPath: tokenizerConfigURL.path) {
            tokenizerConfig = try hubApi.configuration(fileURL: tokenizerConfigURL)
        }

        // Check for chat template and merge if available
        // Prefer .jinja template over .json template
        var chatTemplate: String? = nil
        let chatTemplateJinjaURL = modelFolder.appending(path: "chat_template.jinja")
        let chatTemplateJsonURL = modelFolder.appending(path: "chat_template.json")
        if FileManager.default.fileExists(atPath: chatTemplateJinjaURL.path) {
            // The .jinja template is plain text; a read failure is tolerated and leaves no template.
            chatTemplate = try? String(contentsOf: chatTemplateJinjaURL, encoding: .utf8)
        } else if FileManager.default.fileExists(atPath: chatTemplateJsonURL.path),
                  let chatTemplateConfig = try? hubApi.configuration(fileURL: chatTemplateJsonURL)
        {
            // Fall back to .json template
            chatTemplate = chatTemplateConfig.chatTemplate.string()
        }

        if let chatTemplate {
            // Create or update tokenizer config with chat template
            if var configDict = tokenizerConfig?.dictionary() {
                configDict["chat_template"] = .init(chatTemplate)
                tokenizerConfig = Config(configDict)
            } else {
                tokenizerConfig = Config(["chat_template": chatTemplate])
            }
        }

        return Configurations(
            modelConfig: modelConfig,
            tokenizerConfig: tokenizerConfig,
            tokenizerData: tokenizerData
        )
    } catch let error as Hub.HubClientError {
        // Already a Hub error: pass it through untouched.
        throw error
    } catch {
        let nsError = error as NSError
        // `JSONSerialization` reports malformed JSON in the Cocoa error domain with the
        // `propertyListReadCorrupt` code (3840). The "NSJSONSerialization" domain check is kept
        // so that any error previously classified as a parse error still is.
        let isMalformedJSON =
            (nsError.domain == NSCocoaErrorDomain
                && nsError.code == CocoaError.Code.propertyListReadCorrupt.rawValue)
            || nsError.domain == "NSJSONSerialization"
        if isMalformedJSON {
            throw Hub.HubClientError.parseError("Invalid JSON format: \(nsError.localizedDescription)")
        }
        // Everything else, including a file that vanished between the existence check and the
        // read, is reported as a file-system error.
        throw Hub.HubClientError.fileSystemError(error)
    }
}
/// Loads the tokenizer configuration bundled with this module for a given model type.
///
/// The resource looked up is `<modelType>_tokenizer_config.json` in `Bundle.module`.
///
/// - Parameter modelType: Model type used to build the resource name.
/// - Returns: The parsed configuration, or `nil` when the resource does not exist or cannot be
///   read or parsed. Read and parse failures are printed rather than thrown.
static func fallbackTokenizerConfig(for modelType: String) -> Config? {
    let resourceName = "\(modelType)_tokenizer_config"
    guard let resourceURL = Bundle.module.url(forResource: resourceName, withExtension: "json") else {
        return nil
    }

    do {
        let contents = try Data(contentsOf: resourceURL)
        let json = try JSONSerialization.jsonObject(with: contents, options: [])
        // The top-level JSON value must be an object to be usable as a configuration.
        if let dictionary = json as? [NSString: Any] {
            return Config(dictionary)
        }
        throw Hub.HubClientError.parseError("Failed to parse fallback tokenizer config")
    } catch {
        // Best effort: report the problem and behave as if no fallback were available.
        print("Error loading fallback tokenizer config: \(error.localizedDescription)")
        return nil
    }
}
}