// HuggingChat-Mac/LocalLLM/ModelManager.swift

//
//  ModelManager.swift
//  HuggingChat-Mac
//
//  Created by Cyril Zakka on 8/23/24.
//

import SwiftUI
import Path
import Combine
import UniformTypeIdentifiers
import MLXLLM
import MLX
import MLXRandom
import Hub

/// Lifecycle of the in-memory MLX model container.
enum LoadState {
    case idle
    case loaded(ModelContainer)
    case error(String)

    /// True when the state carries an error message.
    var isError: Bool {
        if case .error = self { return true }
        return false
    }
}

/// On-disk availability of a model, including live download progress (0...1).
enum ModelDownloadState: Equatable {
    case notDownloaded
    case downloading(progress: Double)
    case downloaded
    case error(String)
}

/// Kind of model: text generation (LLM) or speech-to-text (STT).
enum ModelType: Equatable {
    case llm
    case stt
}

// Local representation of HF models
@Observable class LocalModel: Identifiable, Hashable {
    var id: String = UUID().uuidString
    let displayName: String
    var size: String?                    // human-readable size, e.g. "5.2GB"
    let hfURL: String?                   // Hugging Face repo id, e.g. "mlx-community/…"; nil for purely local models
    var localURL: URL?                   // on-disk snapshot directory once downloaded
    var icon: String = "laptopcomputer"
    var modelType: ModelType = .llm
    var downloadState: ModelDownloadState = .notDownloaded

    init(
        id: String = UUID().uuidString,
        displayName: String,
        size: String? = nil,
        hfURL: String? = nil,
        localURL: URL? = nil,
        icon: String = "laptopcomputer",
        modelType: ModelType = .llm,
        downloadState: ModelDownloadState = .notDownloaded
    ) {
        self.id = id
        self.displayName = displayName
        self.size = size
        self.hfURL = hfURL
        self.localURL = localURL
        self.icon = icon
        self.modelType = modelType
        self.downloadState = downloadState
    }

    // Identity-based equality/hashing: two models are the same iff their ids match.
    static func == (lhs: LocalModel, rhs: LocalModel) -> Bool {
        lhs.id == rhs.id
    }

    func hash(into hasher: inout Hasher) {
        hasher.combine(id)
    }
}

/// Manages the catalog of local MLX models: downloading, loading into memory,
/// deletion, and streaming text generation against the currently loaded model.
@Observable class ModelManager {

    var availableModels: [LocalModel] = [
        LocalModel(id: "Qwen2.5-3B-Instruct-bf16",
                   displayName: "Qwen2.5-3B-Instruct",
                   size: "5.2GB",
                   hfURL: "mlx-community/Qwen2.5-3B-Instruct-bf16",
                   localURL: nil),
        LocalModel(id: "SmolLM-135M-Instruct-4bit",
                   displayName: "SmolLM-135M-Instruct-4bit",
                   size: "75.8MB",
                   hfURL: "mlx-community/SmolLM-135M-Instruct-4bit")
    ]

    /// In-flight download tasks keyed by model id so they can be cancelled.
    /// Entries are removed when a download finishes, fails, or is cancelled.
    private var activeDownloads: [String: Task<Void, Error>] = [:]

    // MLX Params
    var globalContainer: ModelContainer?
    var globalConfig: ModelConfiguration?
    let generateParameters = GenerateParameters(temperature: 0.6)
    let maxTokens = 1000
    let displayEveryNTokens = 4          // UI refresh cadence while streaming tokens

    var loadState = LoadState.idle
    var outputText: String = ""
    var running = false
    var messages: [[String: String]] = []

    init() {
        self.fetchAllLocalModels()
    }

    // MARK: - Model Loading

    /// Loads `model` into memory, replacing whatever was previously loaded.
    /// Publishes the outcome through `loadState`.
    func localModelDidChange(to model: LocalModel) async {
        loadState = .idle
        // Fix: the original force-unwrapped `model.hfURL`, crashing for models
        // without a Hugging Face repo id. Surface an error state instead.
        guard let hfURL = model.hfURL else {
            loadState = .error("Model \(model.displayName) has no Hugging Face repo id.")
            return
        }
        let config = ModelConfiguration(id: hfURL, defaultPrompt: "")
        globalConfig = config
        do {
            let container = try await load(modelConfiguration: config)
            globalContainer = container
            loadState = .loaded(container)
        } catch {
            loadState = .error(error.localizedDescription)
        }
    }

    /// Loads a model container without progress reporting.
    private func load(modelConfiguration: ModelConfiguration) async throws -> ModelContainer {
        // Limit the MLX buffer cache while loading (20 MB).
        MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
        // Fix: removed the original's no-op `do { … } catch { throw error }`.
        return try await MLXLLM.loadModelContainer(
            configuration: modelConfiguration,
            progressHandler: { _ in }
        )
    }

    /// Loads a model container, forwarding download/load progress to `progressCallback`.
    /// On failure, publishes the error through `loadState` before rethrowing.
    private func load(
        modelConfiguration: ModelConfiguration,
        progressCallback: @escaping @Sendable (Progress) -> Void
    ) async throws -> ModelContainer {
        do {
            MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
            return try await MLXLLM.loadModelContainer(
                configuration: modelConfiguration,
                progressHandler: progressCallback
            )
        } catch {
            self.loadState = .error(error.localizedDescription)
            throw error
        }
    }

    /// Unloads the current model and resets the load state.
    func cancelLoading() {
        loadState = .idle
        globalContainer = nil
    }

    // MARK: - Download model

    /// Starts downloading `model`'s weights from the Hub, publishing progress
    /// through the model's `downloadState`. No-op if the model is unknown.
    func downloadModel(_ model: LocalModel) {
        guard let modelIndex = availableModels.firstIndex(where: { $0.id == model.id }) else { return }
        // Fix: report a visible error instead of force-unwrapping a nil repo id.
        guard let hfURL = model.hfURL else {
            availableModels[modelIndex].downloadState = .error("Missing Hugging Face repo id.")
            return
        }
        availableModels[modelIndex].downloadState = .downloading(progress: 0)

        let downloadTask = Task {
            do {
                let modelConfig = ModelConfiguration(id: hfURL, defaultPrompt: "")
                let hub = HubApi()
                _ = try await prepareModelDirectory(
                    hub: hub,
                    configuration: modelConfig
                ) { progress in
                    Task { @MainActor in
                        if let idx = self.availableModels.firstIndex(where: { $0.id == model.id }) {
                            self.availableModels[idx].downloadState = .downloading(progress: progress.fractionCompleted)
                        }
                    }
                }
                // Update state to downloaded on success.
                await MainActor.run {
                    if let idx = self.availableModels.firstIndex(where: { $0.id == model.id }) {
                        self.availableModels[idx].downloadState = .downloaded
                    }
                    // Fix: finished tasks previously stayed in `activeDownloads` forever.
                    self.activeDownloads[model.id] = nil
                    self.fetchAllLocalModels()
                }
            } catch {
                await MainActor.run {
                    if let idx = self.availableModels.firstIndex(where: { $0.id == model.id }) {
                        self.availableModels[idx].downloadState = .error(error.localizedDescription)
                    }
                    self.activeDownloads[model.id] = nil
                    self.fetchAllLocalModels()
                }
                throw error
            }
        }
        activeDownloads[model.id] = downloadTask
    }

    /// Cancels an in-flight download for `model`, if any, and resets its state.
    /// New, backward-compatible API: `activeDownloads` previously had no consumer.
    func cancelDownload(_ model: LocalModel) {
        guard let task = activeDownloads.removeValue(forKey: model.id) else { return }
        task.cancel()
        if let idx = availableModels.firstIndex(where: { $0.id == model.id }) {
            availableModels[idx].downloadState = .notDownloaded
        }
    }

    /// Resolves the on-disk directory for `configuration`, downloading weights
    /// (`*.safetensors` + `config.json`) from the Hub when needed.
    /// Falls back to the local directory when the repo is unreachable.
    private func prepareModelDirectory(
        hub: HubApi,
        configuration: ModelConfiguration,
        progressHandler: @Sendable @escaping (Progress) -> Void
    ) async throws -> URL {
        do {
            switch configuration.id {
            case .id(let id):
                // Download the model weights from the named repo.
                let repo = Hub.Repo(id: id)
                let modelFiles = ["*.safetensors", "config.json"]
                return try await hub.snapshot(
                    from: repo,
                    matching: modelFiles,
                    progressHandler: progressHandler)
            case .directory(let directory):
                return directory
            }
        } catch Hub.HubClientError.authorizationRequired {
            // authorizationRequired typically means the named repo doesn't exist
            // on the server, so retry with a local-only configuration.
            return configuration.modelDirectory(hub: hub)
        } catch {
            let nserror = error as NSError
            if nserror.domain == NSURLErrorDomain && nserror.code == NSURLErrorNotConnectedToInternet {
                // NSURLErrorDomain Code=-1009 "The Internet connection appears to be offline."
                // Fall back to the local directory.
                return configuration.modelDirectory(hub: hub)
            } else {
                throw error
            }
        }
    }

    // MARK: - Generate Text

    /// Streams a completion for `prompt` against the loaded model, updating
    /// `outputText` incrementally and appending both turns to `messages`.
    /// No-op while a generation is already running or no model is loaded.
    func generate(prompt: String) async {
        guard !running else { return }
        // Fix: bind once instead of the original's repeated force-unwraps.
        guard let container = globalContainer, let config = globalConfig else { return }

        running = true
        self.outputText = ""

        do {
            messages.append(["role": "user", "content": prompt])
            let promptTokens = try await container.perform { _, tokenizer in
                try tokenizer.applyChatTemplate(messages: messages)
            }

            // Reseed from the wall clock so sampling differs between runs.
            MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

            let result = await container.perform { model, tokenizer in
                MLXLLM.generate(
                    promptTokens: promptTokens,
                    parameters: generateParameters,
                    model: model,
                    tokenizer: tokenizer,
                    extraEOSTokens: config.extraEOSTokens
                ) { tokens in
                    // Stream partial text to the UI every few tokens.
                    if tokens.count % displayEveryNTokens == 0 {
                        let text = tokenizer.decode(tokens: tokens)
                        Task { @MainActor in
                            self.outputText = text
                        }
                    }
                    if tokens.count >= maxTokens {
                        return .stop
                    } else {
                        return .more
                    }
                }
            }

            self.outputText = result.output
            // Fix 1: the model's reply belongs in the history as "assistant",
            // not "system" — chat templates treat "system" as instructions.
            // Fix 2: append unconditionally; the original skipped the append
            // whenever the streamed text already matched the final output,
            // silently dropping the assistant turn from multi-turn history.
            messages.append(["role": "assistant", "content": result.output])
        } catch {
            self.loadState = .error(error.localizedDescription)
        }
        running = false
    }

    /// Clears the conversation history and the visible output.
    func clearText() {
        messages = []
        outputText = ""
    }

    // MARK: - Helper functions

    /// Scans Documents/huggingface/models/mlx-community for downloaded snapshots
    /// and syncs each model's `downloadState`, `localURL`, and `size`.
    /// NOTE(review): assumes the Hub's default download layout under Documents —
    /// confirm this matches where `HubApi` actually writes snapshots.
    func fetchAllLocalModels() {
        guard let documentsPath = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first else { return }
        let modelsDir = documentsPath
            .appendingPathComponent("huggingface")
            .appendingPathComponent("models")
            .appendingPathComponent("mlx-community")
        do {
            let items = try FileManager.default.contentsOfDirectory(
                at: modelsDir,
                includingPropertiesForKeys: [.isDirectoryKey])
            let downloadedModelNames = Set(items.map { $0.lastPathComponent })

            for (index, model) in availableModels.enumerated() {
                guard let hfURL = model.hfURL else { continue }
                // The snapshot directory is named after the repo's last path component.
                let modelName = hfURL.split(separator: "/").last.map(String.init) ?? ""
                if downloadedModelNames.contains(modelName) {
                    if let modelPath = items.first(where: { $0.lastPathComponent == modelName }) {
                        let fileSize = getDirectorySize(url: modelPath.standardizedFileURL)
                        // Update the model with local info.
                        availableModels[index].downloadState = .downloaded
                        availableModels[index].localURL = modelPath
                        availableModels[index].size = fileSize
                    }
                } else {
                    // Reset properties if the model isn't found locally.
                    availableModels[index].downloadState = .notDownloaded
                    availableModels[index].localURL = nil
                }
            }
        } catch {
            print("Error fetching local models: \(error.localizedDescription)")
        }
    }

    /// Deletes the model's on-disk snapshot and refreshes the catalog.
    func deleteLocalModel(_ model: LocalModel) {
        guard let localURL = model.localURL else { return }
        do {
            try Path(url: localURL)?.delete()
            // Fix: refresh state so the UI doesn't keep showing the model as
            // downloaded after its files are gone.
            fetchAllLocalModels()
        } catch {
            print("Error deleting local model: \(error)")
        }
    }

    /// Human-readable size of a single file ("x.xx MB" / "x.xx GB").
    func getFileSize(url: URL) -> String {
        do {
            let resourceValues = try url.resourceValues(forKeys: [.fileSizeKey])
            guard let fileSizeBytes = resourceValues.fileSize else {
                return "File size unavailable"
            }
            let fileSizeMB = Double(fileSizeBytes) / (1024 * 1024)
            let fileSizeGB = fileSizeMB / 1024
            if fileSizeGB >= 1 {
                return String(format: "%.2f GB", fileSizeGB)
            } else {
                return String(format: "%.2f MB", fileSizeMB)
            }
        } catch {
            return "Error: \(error.localizedDescription)"
        }
    }

    /// Human-readable total size of all regular files under `url` (recursive).
    func getDirectorySize(url: URL) -> String {
        let fileManager = FileManager.default
        var totalSize: Int64 = 0

        guard let enumerator = fileManager.enumerator(
            at: url,
            includingPropertiesForKeys: [.fileSizeKey, .isDirectoryKey]) else {
            print("Failed to create enumerator for \(url)")
            return "0 GB"
        }

        for case let fileURL as URL in enumerator {
            do {
                let resourceValues = try fileURL.resourceValues(forKeys: [.fileSizeKey, .isDirectoryKey])
                if let isDirectory = resourceValues.isDirectory, isDirectory {
                    continue
                }
                if let fileSize = resourceValues.fileSize {
                    totalSize += Int64(fileSize)
                }
            } catch {
                print("Error getting size of file \(fileURL): \(error)")
            }
        }

        let formatter = ByteCountFormatter()
        formatter.allowedUnits = [.useGB, .useMB]
        formatter.countStyle = .file
        // Fix: removed the original's pointless Int64 → Int → Int64 round-trip.
        return formatter.string(fromByteCount: totalSize)
    }
}