HuggingChat-Mac/Models/SendPromptHandler.swift (270 lines of code) (raw):

//
//  SendPromptHandler.swift
//  HuggingChat-Mac
//
//  Created by Cyril Zakka on 8/29/24.
//

import Combine
import SwiftUI
import Foundation
import AppKit

extension NSFont {
    /// Returns a font derived from this one with `newTraits` added to its existing symbolic traits.
    ///
    /// - Parameters:
    ///   - newTraits: Traits to insert (e.g. `.bold`, `.monoSpace`).
    ///   - newPointSize: Replacement point size; defaults to the receiver's `pointSize`.
    /// - Returns: The derived font, or `self` when AppKit cannot build a font from the new descriptor.
    func apply(newTraits: NSFontDescriptor.SymbolicTraits, newPointSize: CGFloat? = nil) -> NSFont {
        var existingTraits = fontDescriptor.symbolicTraits
        existingTraits.insert(newTraits)
        let newFontDescriptor = fontDescriptor.withSymbolicTraits(existingTraits)
        return NSFont(descriptor: newFontDescriptor, size: newPointSize ?? pointSize) ?? self
    }
}

/// Metadata carried by a stream event whose `type` is `"file"`.
struct FileMessage: Decodable {
    let name: String
    let sha: String
    let mime: String
}

/// One JSON event from the chat streaming endpoint.
///
/// Only `type` is required; the other fields are filled in depending on the event type.
struct StreamMessage: Decodable {
    let type: String
    /// Text for `"stream"` events, trimmed of characters in `CharacterSet.nulls`.
    /// A missing token decodes as `""`, so this is non-nil after decoding.
    let token: String?
    let subtype: String?
    let message: String?
    let sources: [WebSearchSource]?
    let name: String? // For file messages
    let sha: String? // For file messages
    let mime: String? // For file messages

    enum CodingKeys: CodingKey {
        case type
        case token
        case subtype
        case message
        case sources
        case name
        case sha
        case mime
    }

    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        self.type = try container.decode(String.self, forKey: .type)
        self.token = try container.decodeIfPresent(String.self, forKey: .token)?.trimmingCharacters(in: .nulls) ?? ""
        self.subtype = try container.decodeIfPresent(String.self, forKey: .subtype)
        self.message = try container.decodeIfPresent(String.self, forKey: .message)
        self.sources = try container.decodeIfPresent([WebSearchSource].self, forKey: .sources)
        self.name = try container.decodeIfPresent(String.self, forKey: .name)
        self.sha = try container.decodeIfPresent(String.self, forKey: .sha)
        self.mime = try container.decodeIfPresent(String.self, forKey: .mime)
    }
}

/// A single web-search result shown alongside an assistant reply.
struct WebSearchSource: Identifiable, Decodable {
    var id = UUID()
    let link: URL
    let title: String
    /// Host name with any leading `"www."` removed.
    let hostname: String

    enum CodingKeys: CodingKey {
        case link
        case title
        case hostname
    }

    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        self.link = try container.decode(URL.self, forKey: .link)
        self.title = try container.decode(String.self, forKey: .title)
        // Fall back to the link itself when `hostname` is missing or not a string.
        if let decodedHostname = try? container.decode(String.self, forKey: .hostname) {
            self.hostname = decodedHostname.deletingPrefix("www.")
        } else {
            self.hostname = link.host?.deletingPrefix("www.") ?? link.absoluteString.deletingPrefix("www.")
        }
    }

    // Convenience Init
    init(link: URL, title: String, hostname: String) {
        self.link = link
        self.title = title
        self.hostname = hostname
    }
}

/// Mutable web-search state attached to a `MessageRow`: the latest status message and the sources found.
final class WebSearch {
    var message: String
    var sources: [WebSearchSource]

    init(message: String, sources: [WebSearchSource]) {
        self.message = message
        self.sources = sources
    }
}

/// Payload of a `"webSearch"` stream event.
enum StreamWebSearch {
    case message(String)
    case sources([WebSearchSource])
}

/// The update a single stream event represents.
enum StreamMessageType {
    case started
    case token(String)
    case webSearch(StreamWebSearch)
    case file(FileMessage)
    case skip

    /// Maps a decoded stream event to the update it represents.
    ///
    /// - Returns: `nil` only for `"webSearch"` events with a missing or unrecognised `subtype`;
    ///   every other unhandled event (including `"title"`, or a `"file"` event missing a field) is `.skip`.
    static func messageType(from json: StreamMessage) -> StreamMessageType? {
        switch json.type {
        case "webSearch":
            return webSearch(from: json)
        case "stream":
            return .token(json.token ?? "")
        case "file":
            guard let name = json.name, let sha = json.sha, let mime = json.mime else {
                return .skip
            }
            return .file(FileMessage(name: name, sha: sha, mime: mime))
        default:
            // Includes "title" and any event type this client does not render.
            return .skip
        }
    }

    private static func webSearch(from json: StreamMessage) -> StreamMessageType? {
        guard let messageType = json.subtype else { return nil }
        switch messageType {
        case "sources":
            return .webSearch(.sources(json.sources ?? []))
        case "update":
            return .webSearch(.message(json.message ?? ""))
        default:
            return nil
        }
    }
}

/// Streams one assistant reply for a conversation and folds the streamed events into a single,
/// progressively updated `messageRow`, published through `update`.
final class SendPromptHandler {
    /// Whether the key window (or the app, when there is no key window) uses a dark appearance.
    var isDarkMode: Bool {
        let appearance = NSApp.keyWindow?.effectiveAppearance ?? NSApp.effectiveAppearance
        return appearance.bestMatch(from: [.darkAqua, .aqua]) == .darkAqua
    }

    /// Interval for the throttle on `update` (currently disabled).
    private static let throttleTime: DispatchQueue.SchedulerTimeType.Stride = .milliseconds(100)
    /// Byte separating consecutive JSON events in the response body.
    private static let eventSeparator = UInt8(ascii: "\n")

    private let privateUpdate = PassthroughSubject<StreamMessageType, HFError>()
    private var responseMessage: String = ""
    private var currentTextCount: Int = 0
    private let conversationId: String
    /// Characters to accumulate before re-parsing the whole reply; `0` re-parses on every token.
    private let parserThresholdTextCount = 0
    private var currentOutput: AttributedOutput?
    private var cancellables: [AnyCancellable] = []
    /// Bytes received after the last complete event. This may end mid-event or mid-UTF-8 sequence.
    private var pendingBytes = Data()

    var messageRow: MessageRow
    private var postPrompt: PostStream? = PostStream()

    /// Emits the updated `messageRow` for every stream event that changes it.
    var update: AnyPublisher<MessageRow, HFError> {
        return privateUpdate
            .compactMap { [weak self] (messageType: StreamMessageType) -> MessageRow? in
                // If the handler has been released, drop the event instead of crashing.
                self?.updateMessageRow(with: messageType)
            }
            .eraseToAnyPublisher()
        //            .throttle(
        //                for: SendPromptHandler.throttleTime, scheduler: DispatchQueue.main, latest: true
        //            ).eraseToAnyPublisher()
    }

    init(conversationId: String) {
        self.conversationId = conversationId
        self.messageRow = MessageRow(
            type: .assistant, isInteracting: true, contentType: .rawText(" "))
    }

    /// Retained for source compatibility; stream framing now uses a private byte buffer.
    var tmpMessage: String = ""
    private let decoder: JSONDecoder = JSONDecoder()

    /// Posts `reqBody` to the conversation and forwards each decoded stream event to `update`.
    ///
    /// The body is split into events on `"\n"`. Bytes are buffered instead of being converted to
    /// `String` per chunk. A chunk that ends inside a multi-byte UTF-8 character, or inside an event,
    /// is completed by the next chunk rather than being dropped. Each event is emitted exactly once.
    func sendPromptReq(reqBody: PromptRequestBody) {
        postPrompt?.postPrompt(reqBody: reqBody, conversationId: conversationId)
            .sink(receiveCompletion: { [weak self] completion in
                // Leftover bytes could not be decoded as an event; release them.
                self?.pendingBytes.removeAll()
                switch completion {
                case .finished:
                    self?.privateUpdate.send(completion: .finished)
                case .failure(let error):
                    print("error \(error)")
                    self?.privateUpdate.send(completion: .failure(error))
                }
            }, receiveValue: { [weak self] (data: Data) in
                self?.consume(data)
            })
            .store(in: &cancellables)
    }

    /// Appends one response chunk to `pendingBytes` and emits every event it completes.
    private func consume(_ chunk: Data) {
        pendingBytes.append(chunk)
        var lines = pendingBytes.split(separator: SendPromptHandler.eventSeparator,
                                       omittingEmptySubsequences: false)
        guard let trailing = lines.popLast() else { return }
        for line in lines {
            // A newline-terminated line is a complete event. If it fails to decode, drop it
            // instead of gluing it onto the next event.
            emitEvent(from: Data(line))
        }
        // The final segment has no newline yet. Emit it now if it already decodes (a truncated JSON
        // object never does); otherwise keep it until more bytes arrive.
        let tail = Data(trailing)
        pendingBytes = emitEvent(from: tail) ? Data() : tail
    }

    /// Decodes one JSON event and sends the corresponding `StreamMessageType` to `privateUpdate`.
    ///
    /// - Returns: `true` if `line` decoded as a `StreamMessage`.
    @discardableResult
    private func emitEvent(from line: Data) -> Bool {
        guard !line.isEmpty,
              let event = try? decoder.decode(StreamMessage.self, from: line) else {
            return false
        }
        privateUpdate.send(StreamMessageType.messageType(from: event) ?? .skip)
        return true
    }

    lazy var parsingTask = ResponseParsingTask(isDarkMode: isDarkMode)
    var attributedSend: AttributedOutput = AttributedOutput(string: "", results: [])

    /// Applies one stream event to `messageRow`.
    ///
    /// - Returns: The updated row, or `nil` for `.skip` so subscribers are not notified.
    private func updateMessageRow(with message: StreamMessageType) -> MessageRow? {
        switch message {
        case .started:
            return messageRow
        case .webSearch(let update):
            if messageRow.webSearch == nil {
                messageRow.webSearch = WebSearch(message: "", sources: [])
            }
            switch update {
            case .message(let message):
                messageRow.webSearch?.message = message
            case .sources(let sources):
                messageRow.webSearch?.sources = sources
            }
            return messageRow
        case .token(let token):
            // The first generated token means any web search has finished.
            messageRow.webSearch?.message = "Completed"
            return updateMessage(with: token)
        case .skip:
            return nil
        case .file(let fileMessage):
            messageRow.fileInfo = fileMessage
            return messageRow
        }
    }

    /// Appends `token` to the reply and rebuilds `messageRow.contentType` as attributed output.
    private func updateMessage(with token: String) -> MessageRow {
        // NOTE(review): this per-token parse result is stored but not read in this file; confirm
        // whether external callers use `attributedSend` before removing it.
        attributedSend = parsingTask.parse(text: token)
        responseMessage += token
        currentTextCount += token.count

        if currentTextCount >= parserThresholdTextCount || token.contains("```") {
            currentOutput = parsingTask.parse(text: responseMessage)
            currentTextCount = 0
        }

        if let currentOutput, let lastResult = currentOutput.results.last {
            // Text received since the cached parse; appended to the last parsed result.
            let suffixText = responseMessage.deletingPrefix(currentOutput.string)
            var results = currentOutput.results
            // Copy before appending. `currentOutput` is cached across calls when the threshold is not
            // reached, so appending in place would duplicate the suffix on every call.
            let lastAttrString = NSMutableAttributedString(attributedString: lastResult.attributedString)
            if case .codeBlock(_) = lastResult.resultType {
                lastAttrString.append(
                    NSMutableAttributedString(
                        string: String(suffixText),
                        attributes: [
                            .font: NSFont.systemFont(ofSize: 12).apply(newTraits: .monoSpace),
                            .foregroundColor: NSColor.white,
                        ]))
            } else {
                lastAttrString.append(NSMutableAttributedString(string: String(suffixText)))
            }
            results[results.count - 1] = ParserResult(
                attributedString: lastAttrString, resultType: lastResult.resultType)
            messageRow.contentType = .attributed(.init(string: responseMessage, results: results))
        } else {
            messageRow.contentType = .attributed(
                .init(
                    string: responseMessage,
                    results: [
                        ParserResult(
                            attributedString: NSMutableAttributedString(string: responseMessage),
                            resultType: .text)
                    ]))
        }

        // If the cached parse is stale, re-parse the full reply.
        if let currentString = currentOutput?.string, currentString != responseMessage {
            let output = parsingTask.parse(text: responseMessage)
            messageRow.contentType = .attributed(output)
        }

        return messageRow
    }

    /// Cancels the network subscription and releases the stream client.
    ///
    /// This does not send a completion on `update`.
    func cancel() {
        cancellables = []
        postPrompt = nil
    }
}