SwiftChat/ContentView.swift (232 lines of code) (raw):

//
//  ContentView.swift
//  SwiftChat
//
//  Created by Pedro Cuenca on April 2023
//  Based on code by Cyril Zakka from https://github.com/cyrilzakka/pen
//

import SwiftUI
import Generation
import Models

/// Lifecycle of the language model as reflected in the UI.
enum ModelState: Equatable {
    /// No model is loaded (initial state, or the last load attempt failed).
    case noModel
    /// A model load is in progress.
    case loading
    /// A model is loaded and idle. The payload is the tokens-per-second of the last
    /// completed generation, or `nil` if no generation has completed since loading.
    case ready(Double?)
    /// A generation is running; the payload is the estimated progress in `0...1`
    /// (tokens received / `maxNewTokens`).
    case generating(Double)
    /// The last generation threw; the payload is the error's description.
    case failed(String)
}

struct ContentView: View {
    @Environment(\.horizontalSizeClass) private var horizontalSizeClass
    @Environment(\.verticalSizeClass) private var verticalSizeClass

    @State private var config = GenerationConfig(maxNewTokens: 20)
    @State private var prompt = "Write a poem about Valencia\n"
    @State private var modelURL: URL? = nil
    @State private var languageModel: LanguageModel? = nil

    @State private var isSettingsPresented = false
    @State private var isFirstLaunch = true

    @State private var status: ModelState = .noModel
    @State private var outputText: AttributedString = ""

    /// Any change to this value clears the output area (see `chatView`).
    @Binding var clearTriggered: Bool

    /// Loads the model at `modelURL` and makes it the current `languageModel`.
    ///
    /// On success, the model's default generation config is adopted while keeping the
    /// user's `maxNewTokens` (capped to the model's context length), and the settings
    /// sheet is dismissed. On failure, the error is printed and the state returns to
    /// `.noModel`. If `modelURL` changes while a load is in flight, that newer URL is
    /// loaded once the current load finishes instead of being silently dropped.
    func modelDidChange() {
        guard status != .loading else { return }
        status = .loading
        // Snapshot the URL so we can tell afterwards whether it changed mid-load.
        let requestedURL = modelURL
        Task {
            do {
                languageModel = try await ModelLoader.load(url: requestedURL)
                if let defaultConfig = languageModel?.defaultGenerationConfig {
                    let maxNewTokens = config.maxNewTokens
                    config = defaultConfig
                    Task {
                        // Refresh after slider limits have been updated
                        config.maxNewTokens = min(maxNewTokens, languageModel?.maxContextLength ?? 20)
                    }
                }
                status = .ready(nil)
                isSettingsPresented = false
            } catch {
                print("No model could be loaded: \(error)")
                status = .noModel
            }
            // A URL picked while we were loading was rejected by the `.loading` guard; load it now.
            if modelURL != requestedURL {
                modelDidChange()
            }
        }
    }

    /// Empties the output area.
    func clear() {
        outputText = ""
    }

    /// Generates a continuation of `prompt` with the current model and streams it into `outputText`.
    ///
    /// Does nothing if no model is loaded. The prompt and generation config are
    /// snapshotted up front, so edits made while generating affect neither the prompt
    /// stripping nor the progress estimate. On completion the status becomes
    /// `.ready(tokensPerSecond)`; if generation throws it becomes `.failed`.
    func run() {
        guard let languageModel else { return }

        let promptText = prompt
        let generationConfig = config
        // At least 1 so the progress ratio below is always finite.
        let tokenBudget = Double(max(generationConfig.maxNewTokens, 1))

        /// Publishes a (partial or final) generation and the matching status on the main actor.
        /// - Parameters:
        ///   - currentGeneration: Decoded text so far (expected to begin with the prompt).
        ///   - progress: Estimated progress in `0...1`, used while still generating.
        ///   - completedTokensPerSecond: Non-nil only for the final call; switches the status to `.ready`.
        @Sendable func showOutput(currentGeneration: String, progress: Double, completedTokensPerSecond: Double? = nil) {
            Task { @MainActor in
                // Update the status unconditionally: the text update below may be skipped,
                // but the final call must always leave the `.generating` state.
                if let tps = completedTokensPerSecond {
                    status = .ready(tps)
                } else {
                    status = .generating(progress)
                }

                // Temporary hack to remove start token returned by llama tokenizers
                let response = currentGeneration.deletingPrefix("<s> ")
                // Strip prompt. Assumes the decoded text begins with the prompt; drop by
                // character count rather than subscripting with another string's index.
                guard response.count > promptText.count else { return }
                let generated = String(response.dropFirst(promptText.count))
                    .replacingOccurrences(of: "\\n", with: "\n")

                // Format prompt + response with different colors
                var styledPrompt = AttributedString(promptText)
                styledPrompt.foregroundColor = .primary  // adapts to light/dark mode, unlike `.black`
                var styledOutput = AttributedString(generated)
                styledOutput.foregroundColor = .accentColor
                outputText = styledPrompt + styledOutput
            }
        }

        // Set synchronously so the Run button is replaced before another run can be triggered.
        status = .generating(0)
        Task {
            var tokensReceived = 0
            let begin = Date()
            do {
                let output = try await languageModel.generate(config: generationConfig, prompt: promptText) { inProgressGeneration in
                    tokensReceived += 1
                    showOutput(currentGeneration: inProgressGeneration, progress: min(Double(tokensReceived) / tokenBudget, 1))
                }
                let completionTime = Date().timeIntervalSince(begin)
                let tokensPerSecond = Double(tokensReceived) / completionTime
                showOutput(currentGeneration: output, progress: 1, completedTokensPerSecond: tokensPerSecond)
                print("Took \(completionTime)")
            } catch {
                print("Error \(error)")
                Task { @MainActor in
                    status = .failed("\(error)")
                }
            }
        }
    }

    /// Toolbar control reflecting `status`: nothing without a model, a spinner while
    /// loading, a progress indicator while generating, otherwise a Run button.
    @ViewBuilder
    var runButton: some View {
        switch status {
        case .noModel:
            EmptyView()
        case .loading:
            ProgressView().controlSize(.small).padding(.trailing, 6)
        case .ready, .failed:
            Button(action: run) { Label("Run", systemImage: "play.fill") }
                .keyboardShortcut("R")
        case .generating(let progress):
            ProgressView(value: progress).controlSize(.small).progressViewStyle(.circular).padding(.trailing, 6)
        }
    }

    /// Prompt editor above the styled model output, with `runButton` in the toolbar.
    var chatView: some View {
        GeometryReader { geometry in
            VStack {
                VStack(alignment: .leading, spacing: 4) {
                    Text("Your input (use format appropriate for the model you are using)")
                        .font(.caption)
                        .foregroundColor(.gray)
                    TextEditor(text: $prompt)
                        .font(.body)
                        .fontDesign(.rounded)
                        .scrollContentBackground(.hidden)
                        .multilineTextAlignment(.leading)
                        .padding(.all, 4)
                        .overlay(
                            RoundedRectangle(cornerRadius: 8)
                                .stroke(Color.gray.opacity(0.5), lineWidth: 1)
                        )
                }
                .frame(height: 100)
                .padding(.bottom, 16)

                VStack(alignment: .leading, spacing: 4) {
                    Text("Language Model Output")
                        .font(.caption)
                        .foregroundColor(.gray)
                    Text(outputText)
                        .font(.system(size: 14))
                        .foregroundColor(.blue)
                        .multilineTextAlignment(.leading)
                        .lineLimit(nil)
                        // Clamp at 0: on very narrow layouts `width - 44` would be negative,
                        // which is an invalid frame dimension.
                        .frame(minWidth: max(geometry.size.width - 44, 0), minHeight: 200, alignment: Alignment(horizontal: .leading, vertical: .top))
                        .padding(.all, 4)
                        .overlay(
                            RoundedRectangle(cornerRadius: 8)
                                .stroke(Color.gray.opacity(0.5), lineWidth: 1)
                        )
                        .onChange(of: clearTriggered) { _, _ in
                            clear()
                        }
                }
            }
            .padding()
            .toolbar {
                ToolbarItem(placement: .primaryAction) {
                    runButton
                }
            }
        }
        .navigationTitle("Language Model Tester")
    }

    /// Sidebar (model controls + status) next to the chat view.
    var regularView: some View {
        NavigationSplitView {
            VStack {
                ControlView(prompt: prompt, config: $config, model: $languageModel, modelURL: $modelURL)
                StatusView(status: $status)
            }
            .navigationSplitViewColumnWidth(min: 250, ideal: 300)
        } detail: {
            chatView
        }
    }

#if os(iOS)
    /// Single-column layout; model controls live in a settings sheet that opens automatically on first appearance.
    var compactView: some View {
        NavigationStack {
            VStack {
                chatView
                StatusView(status: $status)
            }
            .navigationBarTitleDisplayMode(.inline)
            .toolbar {
                ToolbarItem(placement: .topBarLeading) {
                    Button("Settings", systemImage: "gear") {
                        isSettingsPresented = true
                    }
                }
            }
        }
        .onAppear {
            if isFirstLaunch {
                isSettingsPresented = true
                isFirstLaunch = false
            }
        }
        .sheet(isPresented: $isSettingsPresented) {
            NavigationStack {
                VStack {
                    ControlView(prompt: prompt, config: $config, model: $languageModel, modelURL: $modelURL)
                    StatusView(status: $status)
                }
                .navigationTitle("Settings")
                .navigationBarTitleDisplayMode(.inline)
                .toolbar {
                    ToolbarItem(placement: .confirmationAction) {
                        Button("Done") {
                            isSettingsPresented = false
                        }
                    }
                }
            }
        }
    }
#endif

    var body: some View {
        Group {
#if os(iOS)
            if horizontalSizeClass == .compact && (verticalSizeClass == .compact || verticalSizeClass == .regular) {
                compactView
            } else {
                regularView
            }
#else
            regularView
#endif
        }
        .onAppear {
            modelDidChange()
        }
        .onChange(of: modelURL) {
            modelDidChange()
        }
    }
}

struct ContentView_Previews: PreviewProvider {
    static var previews: some View {
        ContentView(clearTriggered: .constant(false))
    }
}

extension String {
    /// Returns the string without `prefix` if it starts with it; otherwise returns it unchanged.
    func deletingPrefix(_ prefix: String) -> String {
        guard hasPrefix(prefix) else { return self }
        return String(dropFirst(prefix.count))
    }
}