HuggingChat-Mac/Models/ConversationModel.swift (279 lines of code) (raw):
//
// ConversationModel.swift
// HuggingChat-Mac
//
// Created by Cyril Zakka on 8/29/24.
//
import SwiftUI
import Combine
/// Lifecycle phases of the conversation managed by `ConversationViewModel`.
enum ConversationState: Equatable {
    /// Initial value of the view model's `state` before anything has happened.
    case none
    /// Assigned when the view model is reset.
    case empty
    /// Assigned once history has been received, a reply has finished, or
    /// generation was stopped.
    case loaded
    /// Assigned while conversation history is being fetched.
    case loading
    /// Assigned while a prompt request is in flight.
    case generating
    /// Assigned when a request fails; details live in the view model's `error`.
    case error
}
/// Observable view model for a single HuggingChat conversation.
///
/// Owns the visible message list, the state of the reply currently being
/// streamed, a handful of preferences mirrored from `UserDefaults`, and the
/// text context captured from the active application.
@Observable final class ConversationViewModel {

    // MARK: Constants

    /// `UserDefaults` keys backing the preference properties below.
    private enum DefaultsKey {
        static let useWebSearch = "useWebSearch"
        static let useContext = "useContext"
        static let externalModel = "externalModel"
    }

    private enum Constants {
        /// Model name reported when no preference has been stored yet.
        static let fallbackModelName = "meta-llama/Meta-Llama-3.1-70B-Instruct"
        /// Tool identifiers sent with every prompt when the model supports tools.
        static let toolIDs = ["000000000000000000000001", "000000000000000000000002", "00000000000000000000000a"]
    }

    // MARK: Conversation state

    /// `true` while a reply is being streamed.
    var isInteracting = false
    /// Mirrors the `multimodal` flag of the active model.
    var isMultimodal: Bool = false
    /// Mirrors the `tools` flag of the active model.
    var isTools: Bool = false
    /// The active model as delivered by `DataService`; cast to `LLMModel` before use.
    var model: AnyObject?
    /// The most recent row received from the prompt stream.
    var message: MessageRow? = nil
    /// All rows of the current conversation, oldest first.
    var messages: [MessageRow] = []
    /// The last error surfaced to the UI, if any.
    var error: HFError?

    // Tools
    /// URL of the image produced by the latest reply, if it contained one.
    var imageURL: String?

    // Context
    var contextAppName: String?
    var contextAppSelectedText: String?
    var contextAppFullText: String?
    var contextAppIcon: NSImage?
    var contextIsSupported: Bool = false

    // MARK: Preferences

    // Currently the best way to read an @AppStorage-style value while retaining
    // observability: route reads and writes through the observation registrar.
    var useWebService: Bool {
        get {
            access(keyPath: \.useWebService)
            return UserDefaults.standard.bool(forKey: DefaultsKey.useWebSearch)
        }
        set {
            withMutation(keyPath: \.useWebService) {
                UserDefaults.standard.set(newValue, forKey: DefaultsKey.useWebSearch)
            }
        }
    }

    var useContext: Bool {
        get {
            access(keyPath: \.useContext)
            return UserDefaults.standard.bool(forKey: DefaultsKey.useContext)
        }
        set {
            withMutation(keyPath: \.useContext) {
                UserDefaults.standard.set(newValue, forKey: DefaultsKey.useContext)
            }
        }
    }

    var externalModel: String {
        get {
            access(keyPath: \.externalModel)
            return UserDefaults.standard.string(forKey: DefaultsKey.externalModel) ?? Constants.fallbackModelName
        }
        set {
            withMutation(keyPath: \.externalModel) {
                UserDefaults.standard.set(newValue, forKey: DefaultsKey.externalModel)
            }
        }
    }

    // MARK: Private storage

    /// Live Combine subscriptions; emptying the array cancels all of them.
    private var cancellables = [AnyCancellable]()
    /// Handler of the prompt request currently in flight, if any.
    private var sendPromptHandler: SendPromptHandler?

    /// The conversation being displayed. Setting a non-nil value also records
    /// its server id on the shared session.
    private(set) var conversation: Conversation? {
        didSet {
            guard let conversation = conversation else { return }
            HuggingChatSession.shared.currentConversation = conversation.serverId
        }
    }

    var state: ConversationState = .none

    // MARK: Loading

    /// Makes `conversation` the current conversation and fetches its history.
    func loadConversation(_ conversation: Conversation) {
        self.conversation = conversation
        HuggingChatSession.shared.currentConversation = conversation.serverId
        loadHistory()
    }

    /// Fetches the current conversation from the server and replaces `messages`
    /// with its history. On failure the view model enters the `.error` state
    /// instead of staying in `.loading`.
    private func loadHistory() {
        guard let current = conversation else { return }
        state = .loading
        NetworkService.getConversation(id: current.serverId)
            .receive(on: DispatchQueue.main)
            .sink { [weak self] completion in
                guard let self else { return }
                switch completion {
                case .finished:
                    break
                case .failure(let error):
                    print("Error loading conversation: \(error.localizedDescription)")
                    self.state = .error
                    self.error = .verbose("Something's wrong. Check your internet connection and try again.")
                }
            } receiveValue: { [weak self] fetched in
                guard let self else { return }
                self.conversation = fetched
                self.messages = self.buildHistory(conversation: fetched)
                self.state = .loaded
            }
            .store(in: &cancellables)
    }

    // MARK: Sending

    /// Creates a conversation for the active model and then sends `prompt`.
    /// Does nothing when the active model is not an `LLMModel`.
    private func createConversationAndSendPrompt(_ prompt: String, withFiles: [String]? = nil, usingTools: [String]? = nil) {
        guard let llm = model as? LLMModel else { return }
        createConversation(with: llm, prompt: prompt, withFiles: withFiles, usingTools: usingTools)
    }

    /// Creates a server-side conversation based on `model` and, once it exists,
    /// sends `prompt` through `sendAttributed(text:withFiles:)`.
    ///
    /// `usingTools` is accepted for call-site symmetry but is not forwarded:
    /// the follow-up send derives its tool list from `isTools`.
    private func createConversation(with model: LLMModel, prompt: String, withFiles: [String]? = nil, usingTools: [String]? = nil) {
        state = .loaded
        NetworkService.createConversation(base: model)
            .receive(on: DispatchQueue.main)
            .sink { [weak self] completion in
                switch completion {
                case .finished:
                    print("ConversationViewModel.createConversation finished")
                case .failure(let error):
                    print("ConversationViewModel.createConversation failed:\n\(error)")
                    self?.state = .error
                    self?.error = .verbose("Something's wrong. Check your internet connection and try again.")
                }
            } receiveValue: { [weak self] created in
                print("Recieved")
                self?.conversation = created
                self?.sendAttributed(text: prompt, withFiles: withFiles)
            }
            .store(in: &cancellables)
    }

    /// Sends `text` as a user message, prefixed with the captured app context
    /// when `useContext` is enabled.
    ///
    /// When there is no conversation yet (or it has no message to reply to), a
    /// conversation is created first and the prompt is sent afterwards.
    ///
    /// - Parameters:
    ///   - text: The raw user input; surrounding whitespace is trimmed.
    ///   - withFiles: Optional files forwarded with the request.
    func sendAttributed(text: String, withFiles: [String]? = nil) {
        guard let current = conversation, let previousId = current.messages.last?.id else {
            createConversationAndSendPrompt(text, withFiles: withFiles, usingTools: isTools ? [] : nil)
            return
        }

        let prompt = composePrompt(from: text)
        // TODO: Add files here
        let userRow = MessageRow(type: .user, isInteracting: false, contentType: .rawText(prompt))
        messages.append(userRow)

        let req = PromptRequestBody(
            id: previousId,
            inputs: prompt,
            webSearch: useWebService,
            files: withFiles,
            tools: isTools ? Constants.toolIDs : nil
        )
        sendPromptRequest(req: req, conversationID: current.serverId, pendingUserRow: userRow)
    }

    /// Sends a transcript as a prompt without adding a user row, context, files or tools.
    ///
    /// When no conversation exists yet, one is created first; that fallback
    /// path sends the text through `sendAttributed(text:withFiles:)`.
    func sendTranscript(text: String) {
        guard let current = conversation, let previousId = current.messages.last?.id else {
            createConversationAndSendPrompt(text, withFiles: nil, usingTools: nil)
            return
        }
        let prompt = text.trimmingCharacters(in: .whitespaces)
        let req = PromptRequestBody(id: previousId, inputs: prompt, webSearch: useWebService, files: nil, tools: nil)
        sendPromptRequest(req: req, conversationID: current.serverId)
    }

    /// Builds the prompt text: optional context sections followed by the
    /// trimmed user input, separated by blank lines so the input never runs
    /// into a closing code fence.
    private func composePrompt(from text: String) -> String {
        var sections: [String] = []
        if useContext {
            if let selected = contextAppSelectedText {
                sections.append("Selected Text: ```\(selected)```")
            }
            if let full = contextAppFullText {
                // TODO: Truncate full context if needed
                sections.append("Full Text:```\(full)```")
            }
        }
        sections.append(text.trimmingCharacters(in: .whitespaces))
        return sections.joined(separator: "\n\n")
    }

    /// Starts streaming a reply for `req` and mirrors every update into
    /// `message` / `messages`.
    ///
    /// - Parameters:
    ///   - req: The request body to send.
    ///   - conversationID: Server id of the conversation being continued.
    ///   - pendingUserRow: The user row the caller appended for this request,
    ///     if any. It is removed again, together with the reply placeholder,
    ///     when the server rejects the request for rate limiting.
    private func sendPromptRequest(req: PromptRequestBody, conversationID: String, pendingUserRow: MessageRow? = nil) {
        state = .generating
        isInteracting = true
        imageURL = nil

        let handler = SendPromptHandler(conversationId: conversationID)
        sendPromptHandler = handler
        let placeholderRow = handler.messageRow
        messages.append(placeholderRow)

        handler.update
            .receive(on: DispatchQueue.main)
            // Pair every update with its 1-based sequence number.
            .scan((0, placeholderRow)) { previous, latest in
                (previous.0 + 1, latest)
            }
            .sink { [weak self] completion in
                guard let self else { return }
                // The stream has ended either way, so the interaction is over.
                self.sendPromptHandler = nil
                self.isInteracting = false
                switch completion {
                case .finished:
                    self.state = .loaded
                case .failure(let error):
                    switch error {
                    case .httpTooManyRequest:
                        // Drop only the rows this request added; removing by id
                        // cannot trap or delete unrelated history.
                        let rejectedIDs = [placeholderRow.id, pendingUserRow?.id].compactMap { $0 }
                        self.messages.removeAll { row in
                            rejectedIDs.contains { $0 == row.id }
                        }
                        self.state = .error
                        self.error = .verbose("You've sent too many requests. Please try logging in before sending a message.")
                    default:
                        self.state = .error
                        self.error = error
                    }
                    print(error.localizedDescription)
                }
            } receiveValue: { [weak self] update in
                guard let self else { return }
                let (sequence, row) = update
                // Refresh the conversation once, as soon as the first update arrives.
                if sequence == 1 {
                    self.updateConversation(conversationID: conversationID)
                }
                self.message = row
                print(row)
                if let index = self.messages.lastIndex(where: { $0.id == row.id }) {
                    self.messages[index] = row
                }
                // NOTE(review): the URL is built from `conversation.id`, not the
                // `serverId` used for requests — confirm this is intended.
                if let fileInfo = self.message?.fileInfo,
                   fileInfo.mime.hasPrefix("image/"),
                   let outputConversationID = self.conversation?.id {
                    self.imageURL = "https://huggingface.co/chat/conversation/\(outputConversationID)/output/\(fileInfo.sha)"
                }
            }
            .store(in: &cancellables)

        handler.sendPromptReq(reqBody: req)
    }

    /// Re-fetches the conversation with `conversationID` and stores it in `conversation`.
    private func updateConversation(conversationID: String) {
        NetworkService.getConversation(id: conversationID)
            // Observable state must only be mutated on the main thread.
            .receive(on: DispatchQueue.main)
            .sink { [weak self] completion in
                switch completion {
                case .finished:
                    print("ConversationViewModel.updateConversation finished")
                case .failure(let error):
                    self?.state = .error
                    self?.error = .verbose("Uh oh, something's not right! Please check your connection and try again later.")
                    print(error.localizedDescription)
                }
            } receiveValue: { [weak self] refreshed in
                self?.conversation = refreshed
            }
            .store(in: &cancellables)
    }

    // MARK: Model

    /// Loads the active model and mirrors its name and capability flags.
    /// A model that is not an `LLMModel` is stored but otherwise ignored.
    func getActiveModel() {
        DataService.shared.getActiveModel()
            .receive(on: DispatchQueue.main)
            .sink { [weak self] completion in
                switch completion {
                case .finished:
                    print("ConversationViewModel.getActiveModel finished")
                case .failure(let error):
                    self?.state = .error
                    self?.error = .verbose("Hmm, that didn't go as planned. Please check your connection and try again.")
                    print("ConversationViewModel.getActiveModel failed:\n \(error)")
                }
            } receiveValue: { [weak self] model in
                guard let self else { return }
                self.model = model
                guard let llm = model as? LLMModel else { return }
                self.externalModel = llm.name
                self.isMultimodal = llm.multimodal
                self.isTools = llm.tools
            }
            .store(in: &cancellables)
    }

    /// Converts the messages of `conversation` into display rows, skipping
    /// any message for which no row can be created.
    private func buildHistory(conversation: Conversation) -> [MessageRow] {
        conversation.messages.compactMap { MessageRow(message: $0) }
    }

    // MARK: Lifecycle

    /// Returns the view model to its empty state and reloads the active model.
    func reset() {
        state = .empty
        // Drop old subscriptions BEFORE starting new work: emptying the array
        // afterwards would immediately cancel the model request as well.
        cancellables = []
        sendPromptHandler?.cancel()
        sendPromptHandler = nil
        getActiveModel()
        conversation = nil
        error = nil
        isInteracting = false
        HuggingChatSession.shared.currentConversation = ""
        clearContext()
    }

    /// Cancels the reply currently being streamed and marks the conversation as loaded.
    func stopGenerating() {
        cancellables = []
        sendPromptHandler?.cancel()
        completeInteration()
    }

    /// Clears all in-flight interaction state.
    private func completeInteration() {
        isInteracting = false
        sendPromptHandler = nil
        state = .loaded
        error = nil
    }

    // MARK: Context Functions

    /// Clears the stored context and then asynchronously reads the active
    /// editor's content; the context properties are updated on the main actor.
    /// Text is only stored when the reader reports the app as supported.
    func fetchContext() {
        clearContext()
        Task {
            guard let content = await AccessibilityContentReader.shared.getActiveEditorContent() else { return }
            await MainActor.run {
                self.contextIsSupported = content.isSupported
                self.contextAppName = content.applicationName
                self.contextAppIcon = content.applicationIcon
                if self.contextIsSupported {
                    self.contextAppSelectedText = content.selectedText
                    self.contextAppFullText = content.fullText
                }
            }
        }
    }

    func formatContext() {
        // TODO: Truncate contextAppFullText from start to 3000 characters.
    }

    /// Resets every context property, including the icon and support flag, so
    /// no stale app information outlives the text it belonged to.
    func clearContext() {
        contextAppName = nil
        contextAppSelectedText = nil
        contextAppFullText = nil
        contextAppIcon = nil
        contextIsSupported = false
    }
}