HuggingChat-Mac/LocalSTT/AudioModelManager.swift (742 lines of code) (raw):

//
//  AudioModelManager.swift
//  HuggingChat-Mac
//
//  Created by Cyril Zakka on 9/9/24.
//

import SwiftUI
import WhisperKit
import AppKit
import AVFoundation
import CoreML

/// Transcription mode used by `AudioModelManager`.
/// `.streaming` enables eager (word-timestamp) decoding and collapses all progress into a
/// single display chunk; `.recording` keys display chunks by decoding window.
enum TranscriptionMode {
    case streaming
    case recording
}

/// Origin of the current transcription request, as passed to `startRecording(_:source:)`.
/// NOTE(review): presumably identifies the UI surface that started recording — confirm with callers.
enum TranscriptionSource {
    case chat
    case transcriptionView
    case none
}

/// Manages WhisperKit speech-to-text models (download, load, delete) and drives microphone and
/// file transcription, exposing the resulting text and timing statistics as observable state.
@Observable class AudioModelManager {
    var whisperKit: WhisperKit? = nil
    var audioDevices: [AudioDevice]? = nil
    var isRecording: Bool = false
    var isTranscribing: Bool = false
    var currentText: String = ""
    var currentChunks: [Int: (chunkText: [String], fallbacks: Int)] = [:]
    var modelStorage: String = "huggingface/models/argmaxinc/whisperkit-coreml"
    var modelState: ModelState = .unloaded
    var localModels: [String] = []
    var localModelPath: String = ""
    var availableModels: [String] = []
    var availableLanguages: [String] = []
    var disabledModels: [String] = WhisperKit.recommendedModels().disabled
    var repoName: String = "argmaxinc/whisperkit-coreml"
    var silenceThreshold: Double = 0.3
    var isTranscriptionComplete: Bool = false
    var transcriptionSource: TranscriptionSource = .none

    /// Name of the preferred audio input device, persisted in `UserDefaults`.
    /// `"None"` means no device has been chosen. Observation is wired manually because the
    /// value lives in `UserDefaults` rather than in a stored property.
    var selectedAudioInput: String {
        get {
            access(keyPath: \.selectedAudioInput)
            return UserDefaults.standard.string(forKey: "selectedAudioInput") ?? "None"
        }
        set {
            withMutation(keyPath: \.selectedAudioInput) {
                UserDefaults.standard.setValue(newValue, forKey: "selectedAudioInput")
            }
        }
    }

    // MARK: Decoding settings

    private var selectedTask: String = "transcribe"
    private var selectedLanguage: String = "english"
    private var enableTimestamps: Bool = false
    private var enablePromptPrefill: Bool = true
    private var enableCachePrefill: Bool = true
    private var enableSpecialCharacters: Bool = false
    private var enableEagerDecoding: Bool = true
    private var temperatureStart: Double = 0
    private var fallbackCount: Double = 5
    private var compressionCheckWindow: Double = 60
    private var sampleLength: Double = 224
    private var useVAD: Bool = false
    private var tokenConfirmationsNeeded: Double = 2
    private var chunkingStrategy: ChunkingStrategy = .none
    private var encoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine
    private var decoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine

    // MARK: Standard properties

    var loadingProgressValue: Float = 0.0
    /// Share of the loading progress bar allotted to download/specialization; the rest covers prewarm and load.
    var specializationProgressRatio: Float = 0.7
    var firstTokenTime: TimeInterval = 0
    var pipelineStart: TimeInterval = 0
    var effectiveRealTimeFactor: TimeInterval = 0
    var effectiveSpeedFactor: TimeInterval = 0
    var totalInferenceTime: TimeInterval = 0
    var tokensPerSecond: TimeInterval = 0
    var currentLag: TimeInterval = 0
    var currentFallbacks: Int = 0
    var currentEncodingLoops: Int = 0
    var currentDecodingLoops: Int = 0
    var lastBufferSize: Int = 0
    var lastConfirmedSegmentEndSeconds: Float = 0
    var requiredSegmentsForConfirmation: Int = 4
    var bufferEnergy: [Float] = []
    var bufferSeconds: Double = 0
    var confirmedSegments: [TranscriptionSegment] = []
    var unconfirmedSegments: [TranscriptionSegment] = []

    // MARK: Eager mode properties

    var eagerResults: [TranscriptionResult?] = []
    var prevResult: TranscriptionResult?
    var lastAgreedSeconds: Float = 0.0
    var prevWords: [WordTiming] = []
    var lastAgreedWords: [WordTiming] = []
    var confirmedWords: [WordTiming] = []
    var confirmedText: String = ""
    var hypothesisWords: [WordTiming] = []
    var hypothesisText: String = ""

    // MARK: UI properties

    private var transcriptionMode: TranscriptionMode = .recording
    private var transcriptionTask: Task<Void, Never>? = nil
    private var transcribeTask: Task<Void, Never>? = nil

    // MARK: Compatibility

    var availableLocalModels: [LocalModel] = [
        LocalModel(
            id: WhisperKit.recommendedModels().default,
            displayName: WhisperKit.recommendedModels().default.capitalized,
            size: "",
            hfURL: "argmaxinc/whisperkit-coreml/\(WhisperKit.recommendedModels().default)",
            localURL: nil,
            icon: "waveform.badge.mic",
            modelType: .stt
        )
    ]

    // MARK: Model Management

    /// Cancels any file/buffer transcription, stops recording and clears all transcript and
    /// timing state back to its starting values.
    func resetState() {
        transcribeTask?.cancel()
        isRecording = false
        isTranscribing = false
        whisperKit?.audioProcessor.stopRecording()
        currentText = ""
        currentChunks = [:]

        pipelineStart = Double.greatestFiniteMagnitude
        firstTokenTime = Double.greatestFiniteMagnitude

        effectiveRealTimeFactor = 0
        effectiveSpeedFactor = 0
        totalInferenceTime = 0
        tokensPerSecond = 0
        currentLag = 0
        currentFallbacks = 0
        currentEncodingLoops = 0
        currentDecodingLoops = 0
        lastBufferSize = 0
        lastConfirmedSegmentEndSeconds = 0
        // NOTE(review): the property is declared with 4 but reset to 2 — confirm which is intended.
        requiredSegmentsForConfirmation = 2
        bufferEnergy = []
        bufferSeconds = 0
        confirmedSegments = []
        unconfirmedSegments = []

        eagerResults = []
        prevResult = nil
        lastAgreedSeconds = 0.0
        prevWords = []
        lastAgreedWords = []
        confirmedWords = []
        confirmedText = ""
        hypothesisWords = []
        hypothesisText = ""
    }

    /// Animates `loadingProgressValue` toward `targetProgress` along an exponential-decay
    /// curve that reaches the target in at most about `maxTime` seconds.
    ///
    /// Updates are published on the main actor every 100 ms; the loop ends when the target is
    /// reached or the enclosing task is cancelled (the sleep throws).
    /// - Parameters:
    ///   - targetProgress: Progress value to approach. Values `>= 1` are applied immediately.
    ///   - maxTime: Time budget in seconds. Non-positive budgets apply the target immediately.
    func updateProgressBar(targetProgress: Float, maxTime: TimeInterval) async {
        // The curve is only defined for targets below 1 and a positive budget; otherwise
        // `log` yields NaN/inf and the loop below might never terminate.
        guard targetProgress < 1, maxTime > 0 else {
            await MainActor.run {
                loadingProgressValue = min(targetProgress, 1)
            }
            return
        }

        let initialProgress = loadingProgressValue
        let decayConstant = -log(1 - targetProgress) / Float(maxTime)
        let startTime = Date()

        while true {
            let elapsedTime = Float(Date().timeIntervalSince(startTime))
            let decayFactor = exp(-decayConstant * elapsedTime)
            let currentProgress = initialProgress + (1 - initialProgress) * (1 - decayFactor)

            await MainActor.run {
                loadingProgressValue = currentProgress
            }

            if currentProgress >= targetProgress {
                break
            }

            do {
                try await Task.sleep(nanoseconds: 100_000_000)
            } catch {
                break
            }
        }
    }

    /// Compute units used for the audio encoder and text decoder.
    func getComputeOptions() -> ModelComputeOptions {
        return ModelComputeOptions(audioEncoderCompute: encoderComputeUnits, textDecoderCompute: decoderComputeUnits)
    }

    /// Downloads `model` into the local WhisperKit model store. If it is already local and
    /// `redownload` is false, only marks loading as complete.
    ///
    /// Progress and the final state are published through
    /// `availableLocalModels[...].downloadState`. On failure, that state is restored to
    /// reflect whether the model is present locally instead of being left at `.downloading`.
    /// Models not listed in `availableLocalModels` are ignored.
    func downloadModel(_ model: LocalModel, redownload: Bool = false) {
        guard let modelIndex = availableLocalModels.firstIndex(where: { $0.id == model.id }) else { return }

        whisperKit = nil
        Task {
            do {
                whisperKit = try await WhisperKit(
                    computeOptions: getComputeOptions(),
                    verbose: true,
                    logLevel: .debug,
                    prewarm: false,
                    load: false,
                    download: false
                )
                guard whisperKit != nil else { return }

                if localModels.contains(model.id) && !redownload {
                    // Already on disk: nothing to fetch.
                    await MainActor.run {
                        loadingProgressValue = 1.0
                    }
                    return
                }

                await MainActor.run {
                    self.availableLocalModels[modelIndex].downloadState = .downloading(progress: 0)
                }
                _ = try await WhisperKit.download(variant: model.id, from: repoName, progressCallback: { progress in
                    DispatchQueue.main.async {
                        self.availableLocalModels[modelIndex].downloadState = .downloading(progress: progress.fractionCompleted)
                    }
                })
                await MainActor.run {
                    self.availableLocalModels[modelIndex].downloadState = .downloaded
                    if !localModels.contains(model.id) {
                        localModels.append(model.id)
                    }
                    fetchModels()
                }
            } catch {
                print("Error downloading model \(model.id): \(error.localizedDescription)")
                await MainActor.run {
                    // Don't leave the UI stuck at `.downloading` after a failure.
                    self.availableLocalModels[modelIndex].downloadState = localModels.contains(model.id) ? .downloaded : .notDownloaded
                }
            }
        }
    }

    /// Creates a fresh `WhisperKit` instance and loads `model` into it. The model is downloaded
    /// first when it is not available locally or when `redownload` is true; it is then
    /// prewarmed and loaded.
    ///
    /// Progress is reported through `loadingProgressValue` and `modelState`. If prewarming fails
    /// on the first attempt, the model is re-downloaded once. Any other failure resets
    /// `modelState` to `.unloaded` instead of leaving it stuck mid-load.
    func loadModel(_ model: String, redownload: Bool = false) {
        whisperKit = nil
        Task {
            do {
                whisperKit = try await WhisperKit(
                    computeOptions: getComputeOptions(),
                    verbose: true,
                    logLevel: .debug,
                    prewarm: false,
                    load: false,
                    download: false
                )
                guard let whisperKit = whisperKit else { return }

                let folder: URL?
                if localModels.contains(model) && !redownload {
                    // TODO: Make this configurable in the UI
                    folder = URL(fileURLWithPath: localModelPath).appendingPathComponent(model)
                } else {
                    folder = try await WhisperKit.download(variant: model, from: repoName, progressCallback: { progress in
                        DispatchQueue.main.async {
                            self.loadingProgressValue = Float(progress.fractionCompleted) * self.specializationProgressRatio
                            self.modelState = .downloading
                        }
                    })
                }

                await MainActor.run {
                    loadingProgressValue = specializationProgressRatio
                    modelState = .downloaded
                }

                guard let modelFolder = folder else { return }
                whisperKit.modelFolder = modelFolder

                await MainActor.run {
                    // Download/specialization share is done; prewarming animates the next stretch.
                    loadingProgressValue = specializationProgressRatio
                    modelState = .prewarming
                }

                let progressBarTask = Task {
                    await updateProgressBar(targetProgress: 0.9, maxTime: 240)
                }

                do {
                    try await whisperKit.prewarmModels()
                    progressBarTask.cancel()
                } catch {
                    print("Error prewarming models, retrying: \(error.localizedDescription)")
                    progressBarTask.cancel()
                    if !redownload {
                        // Prewarm failed on this first attempt: re-download the model and start over once.
                        loadModel(model, redownload: true)
                    } else {
                        // Redownloading failed, error out
                        await MainActor.run {
                            modelState = .unloaded
                        }
                    }
                    return
                }

                await MainActor.run {
                    // Prewarm done: jump to 90% of the range that follows specialization.
                    loadingProgressValue = specializationProgressRatio + 0.9 * (1 - specializationProgressRatio)
                    modelState = .loading
                }

                try await whisperKit.loadModels()

                await MainActor.run {
                    if !localModels.contains(model) {
                        localModels.append(model)
                    }
                    availableLanguages = Constants.languages.map { $0.key }.sorted()
                    loadingProgressValue = 1.0
                    modelState = whisperKit.modelState
                }
            } catch {
                print("Error loading model \(model): \(error.localizedDescription)")
                await MainActor.run {
                    modelState = .unloaded
                }
            }
        }
    }

    /// Scans `<Documents>/<modelStorage>` for downloaded models and merges them into
    /// `localModels`. It then updates each `availableLocalModels` entry's download state,
    /// local URL and size to match.
    func fetchModels() {
        availableModels = []

        if let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first {
            let modelPath = documents.appendingPathComponent(modelStorage).path
            if FileManager.default.fileExists(atPath: modelPath) {
                localModelPath = modelPath
                do {
                    let downloadedModels = try FileManager.default.contentsOfDirectory(atPath: modelPath)
                    for model in downloadedModels where !localModels.contains(model) {
                        localModels.append(model)
                    }
                } catch {
                    print("Error enumerating files at \(modelPath): \(error.localizedDescription)")
                }
            }
        }

        localModels = WhisperKit.formatModelFiles(localModels)

        for availableLocalModel in availableLocalModels {
            if localModels.contains(availableLocalModel.id) {
                let fileSize = getDirectorySize(selectedModel: availableLocalModel.id)
                availableLocalModel.downloadState = .downloaded
                availableLocalModel.localURL = URL(fileURLWithPath: localModelPath).appendingPathComponent(availableLocalModel.id)
                availableLocalModel.size = fileSize
            } else {
                availableLocalModel.downloadState = .notDownloaded
                availableLocalModel.localURL = nil
                availableLocalModel.size = nil
            }
        }

        // TODO: Populate `availableModels` from `localModels` plus
        // `WhisperKit.fetchAvailableModels(from: repoName)`, skipping `disabledModels`.
    }

    /// Total on-disk size of a local model folder, formatted in MB/GB.
    ///
    /// Returns the formatted zero size when `selectedModel` is not in `localModels`, and
    /// `"--"` when the folder cannot be enumerated.
    func getDirectorySize(selectedModel: String) -> String {
        var totalSize: Int64 = 0

        if localModels.contains(selectedModel) {
            let modelFolder = URL(fileURLWithPath: localModelPath).appendingPathComponent(selectedModel)
            guard let enumerator = FileManager.default.enumerator(
                at: modelFolder,
                includingPropertiesForKeys: [.fileSizeKey, .isDirectoryKey]
            ) else {
                print("Failed to create enumerator for \(modelFolder)")
                return "--"
            }

            for case let fileURL as URL in enumerator {
                do {
                    let resourceValues = try fileURL.resourceValues(forKeys: [.fileSizeKey, .isDirectoryKey])
                    if resourceValues.isDirectory == true {
                        continue
                    }
                    if let fileSize = resourceValues.fileSize {
                        totalSize += Int64(fileSize)
                    }
                } catch {
                    print("Error getting size of file \(fileURL): \(error)")
                }
            }
        }

        let formatter = ByteCountFormatter()
        formatter.allowedUnits = [.useGB, .useMB]
        formatter.countStyle = .file
        return formatter.string(fromByteCount: totalSize)
    }

    /// Creation date of a local model folder as a relative string (e.g. "2 days ago"),
    /// or `"--"` when the attributes cannot be read.
    func getFileCreationDate(for selectedModel: String) -> String? {
        let filePath = URL(fileURLWithPath: localModelPath).appendingPathComponent(selectedModel)
        guard let attributes = try? FileManager.default.attributesOfItem(atPath: filePath.path()),
              let creationDate = attributes[.creationDate] as? Date else {
            return "--"
        }
        let formatter = RelativeDateTimeFormatter()
        formatter.unitsStyle = .full
        return formatter.localizedString(for: creationDate, relativeTo: Date())
    }

    /// Removes a local model folder from disk, drops it from `localModels` and sets
    /// `modelState` to `.unloaded`. Does nothing for models not in `localModels`.
    func deleteModel(selectedModel: String) {
        guard localModels.contains(selectedModel) else { return }
        let modelFolder = URL(fileURLWithPath: localModelPath).appendingPathComponent(selectedModel)
        do {
            try FileManager.default.removeItem(at: modelFolder)
            if let index = localModels.firstIndex(of: selectedModel) {
                localModels.remove(at: index)
            }
            modelState = .unloaded
        } catch {
            print("Error deleting model: \(error)")
        }
    }

    /// Resets state and transcribes the audio file at `path`, storing the resulting segments
    /// in `confirmedSegments`.
    func transcribeFile(path: String) {
        resetState()
        whisperKit?.audioProcessor = AudioProcessor()
        self.transcribeTask = Task {
            isTranscribing = true
            do {
                try await transcribeCurrentFile(path: path)
            } catch {
                print("File selection error: \(error.localizedDescription)")
            }
            isTranscribing = false
        }
    }

    /// Requests microphone permission, resolves the selected input device and starts live
    /// capture, mirroring buffer energy and duration into observable state.
    ///
    /// When `loop` is true, the realtime transcription loop is also started. Failures
    /// (no permission, no usable device, capture error) are logged and leave the recording
    /// flags untouched.
    func startRecording(_ loop: Bool, source: TranscriptionSource = .none) {
        transcriptionSource = source
        guard let audioProcessor = whisperKit?.audioProcessor else { return }

        Task(priority: .userInitiated) {
            guard await AudioProcessor.requestRecordPermission() else {
                print("Microphone access was not granted.")
                return
            }
            setupMicrophone()

            var deviceId: DeviceID?
            if self.selectedAudioInput != "None",
               let device = self.audioDevices?.first(where: { $0.name == selectedAudioInput }) {
                deviceId = device.id
            }

            // No usable input device: previously this threw inside the Task, where the
            // error was silently discarded.
            guard let deviceId else {
                print("Microphone unavailable: no audio input device could be selected.")
                return
            }

            do {
                try audioProcessor.startRecordingLive(inputDeviceID: deviceId) { _ in
                    DispatchQueue.main.async { [self] in
                        self.bufferEnergy = self.whisperKit?.audioProcessor.relativeEnergy ?? []
                        bufferSeconds = Double(self.whisperKit?.audioProcessor.audioSamples.count ?? 0) / Double(WhisperKit.sampleRate)
                    }
                }
            } catch {
                // Don't flag recording as active when capture never started.
                print("Error starting live recording: \(error.localizedDescription)")
                return
            }

            isRecording = true
            isTranscribing = true
            if loop {
                realtimeLoop()
            }
        }
    }

    /// Stops capture and the realtime loop, then finalizes the transcript.
    ///
    /// With `loop == false`, the whole captured buffer is transcribed once before finalizing.
    /// `isTranscriptionComplete` becomes true once the text is committed.
    func stopRecording(_ loop: Bool) {
        isRecording = false
        bufferSeconds = 0
        stopRealtimeTranscription()
        isTranscriptionComplete = false
        whisperKit?.audioProcessor.stopRecording()

        if loop {
            // The realtime loop has been transcribing while recording; just commit pending text.
            finalizeText()
            isTranscriptionComplete = true
        } else {
            self.transcribeTask = Task {
                isTranscribing = true
                do {
                    try await transcribeCurrentBuffer()
                } catch {
                    print("Error: \(error.localizedDescription)")
                }
                finalizeText()
                isTranscribing = false
                await MainActor.run {
                    isTranscriptionComplete = true
                }
            }
        }
    }

    /// Commits pending hypothesis text and unconfirmed segments to their confirmed
    /// counterparts. The work is scheduled on the main actor and runs asynchronously.
    func finalizeText() {
        Task {
            await MainActor.run {
                if !hypothesisText.isEmpty {
                    confirmedText += hypothesisText
                    hypothesisText = ""
                }
                if !unconfirmedSegments.isEmpty {
                    confirmedSegments.append(contentsOf: unconfirmedSegments)
                    unconfirmedSegments = []
                }
            }
        }
    }

    /// Refreshes `audioDevices` and validates `selectedAudioInput`.
    ///
    /// If the stored input is `"None"` or no longer present, it falls back to the system
    /// default audio capture device, or to `"None"` when there is none. Does nothing when no
    /// devices are found.
    func setupMicrophone() {
        audioDevices = AudioProcessor.getAudioDevices()
        guard let audioDevices, !audioDevices.isEmpty else { return }

        if selectedAudioInput != "None",
           let device = audioDevices.first(where: { $0.name == selectedAudioInput }) {
            selectedAudioInput = device.name
        } else if let defaultDevice = AVCaptureDevice.default(for: .audio) {
            selectedAudioInput = defaultDevice.localizedName
        } else {
            selectedAudioInput = "None"
        }
    }

    // MARK: - Transcribe Logic

    /// Loads the audio file at `path`, transcribes it and publishes segments plus timing stats.
    func transcribeCurrentFile(path: String) async throws {
        // Load and convert the buffer in a limited scope so the intermediate buffer is released early.
        let audioFileSamples = try await Task {
            try autoreleasepool {
                let audioFileBuffer = try AudioProcessor.loadAudio(fromPath: path)
                return AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
            }
        }.value

        let transcription = try await transcribeAudioSamples(audioFileSamples)

        await MainActor.run {
            currentText = ""
            guard let segments = transcription?.segments else {
                return
            }

            self.tokensPerSecond = transcription?.timings.tokensPerSecond ?? 0
            self.effectiveRealTimeFactor = transcription?.timings.realTimeFactor ?? 0
            self.effectiveSpeedFactor = transcription?.timings.speedFactor ?? 0
            self.currentEncodingLoops = Int(transcription?.timings.totalEncodingRuns ?? 0)
            self.firstTokenTime = transcription?.timings.firstTokenTime ?? 0
            self.pipelineStart = transcription?.timings.pipelineStart ?? 0
            self.currentLag = transcription?.timings.decodingLoop ?? 0

            self.confirmedSegments = segments
        }
    }

    /// Early-stopping check shared by the decoding callbacks.
    ///
    /// Returns `false` when the trailing `compressionCheckWindow` tokens exceed the options'
    /// compression-ratio threshold, or when the average log-probability falls below the
    /// options' log-prob threshold; otherwise `nil`. Missing thresholds or log-probabilities
    /// skip the corresponding check instead of crashing.
    private func earlyStoppingDecision(for progress: TranscriptionProgress, options: DecodingOptions) -> Bool? {
        let currentTokens = progress.tokens
        let checkWindow = Int(compressionCheckWindow)
        if currentTokens.count > checkWindow,
           let compressionThreshold = options.compressionRatioThreshold {
            let checkTokens: [Int] = currentTokens.suffix(checkWindow)
            if compressionRatio(of: checkTokens) > compressionThreshold {
                Logging.debug("Early stopping due to compression threshold")
                return false
            }
        }

        if let avgLogprob = progress.avgLogprob,
           let logProbThreshold = options.logProbThreshold,
           avgLogprob < logProbThreshold {
            Logging.debug("Early stopping due to logprob threshold")
            return false
        }

        return nil
    }

    /// Transcribes `samples` using timestamp tokens, starting at `lastConfirmedSegmentEndSeconds`.
    ///
    /// Live progress is mirrored into `currentChunks`/`currentText`. Returns the merged
    /// result, or `nil` when no model is loaded.
    func transcribeAudioSamples(_ samples: [Float]) async throws -> TranscriptionResult? {
        guard let whisperKit = whisperKit else { return nil }

        let languageCode = Constants.languages[selectedLanguage, default: Constants.defaultLanguageCode]
        let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate
        let seekClip: [Float] = [lastConfirmedSegmentEndSeconds]

        let options = DecodingOptions(
            verbose: true,
            task: task,
            language: languageCode,
            temperature: Float(temperatureStart),
            temperatureFallbackCount: Int(fallbackCount),
            sampleLength: Int(sampleLength),
            usePrefillPrompt: enablePromptPrefill,
            usePrefillCache: enableCachePrefill,
            skipSpecialTokens: !enableSpecialCharacters,
            withoutTimestamps: !enableTimestamps,
            wordTimestamps: true,
            clipTimestamps: seekClip,
            chunkingStrategy: chunkingStrategy
        )

        // Mirrors decoding progress into the UI, then applies early stopping.
        let decodingCallback: ((TranscriptionProgress) -> Bool?) = { [self] (progress: TranscriptionProgress) in
            DispatchQueue.main.async {
                let fallbacks = Int(progress.timings.totalDecodingFallbacks)
                let chunkId = self.transcriptionMode == .streaming ? 0 : progress.windowId

                var updatedChunk = (chunkText: [progress.text], fallbacks: fallbacks)
                if var currentChunk = self.currentChunks[chunkId],
                   let previousChunkText = currentChunk.chunkText.last {
                    if progress.text.count >= previousChunkText.count {
                        // Same window of an existing chunk: update its last value in place.
                        currentChunk.chunkText[currentChunk.chunkText.endIndex - 1] = progress.text
                        updatedChunk = currentChunk
                    } else if fallbacks == currentChunk.fallbacks && self.transcriptionMode == .streaming {
                        // New window (fallback count unchanged). NOTE(review): the original
                        // `first ?? "" + progress.text` parsed as `first ?? ("" + progress.text)`,
                        // so the chunk simply restarts with the new window's text.
                        updatedChunk.chunkText = [updatedChunk.chunkText.first ?? progress.text]
                    } else {
                        // Fallback: overwrite the previous bad text. Index into this chunk's
                        // own array so the write can never go out of bounds.
                        updatedChunk.chunkText[updatedChunk.chunkText.endIndex - 1] = progress.text
                        updatedChunk.fallbacks = fallbacks
                        print("Fallback occured: \(fallbacks)")
                    }
                }

                self.currentChunks[chunkId] = updatedChunk
                let joinedChunks = self.currentChunks
                    .sorted { $0.key < $1.key }
                    .flatMap { $0.value.chunkText }
                    .joined(separator: "\n")

                self.currentText = joinedChunks
                self.currentFallbacks = fallbacks
                self.currentDecodingLoops += 1
            }

            return earlyStoppingDecision(for: progress, options: options)
        }

        let transcriptionResults: [TranscriptionResult] = try await whisperKit.transcribe(
            audioArray: samples,
            decodeOptions: options,
            callback: decodingCallback
        )

        return mergeTranscriptionResults(transcriptionResults)
    }

    // MARK: Streaming Logic

    /// Repeatedly transcribes the live buffer while recording and transcribing are both
    /// active. Stops on the first error or when the loop task is cancelled.
    func realtimeLoop() {
        transcriptionTask = Task {
            while isRecording && isTranscribing && !Task.isCancelled {
                do {
                    try await transcribeCurrentBuffer()
                } catch {
                    print("Error: \(error.localizedDescription)")
                    break
                }
            }
        }
    }

    /// Clears `isTranscribing` and cancels the realtime loop task.
    func stopRealtimeTranscription() {
        isTranscribing = false
        transcriptionTask?.cancel()
    }

    /// Transcribes the audio captured so far, if enough new audio has arrived since the last
    /// pass (and, with VAD enabled, voice was detected).
    ///
    /// In streaming mode with eager decoding it uses word-level agreement. Otherwise it uses
    /// timestamp tokens and moves all but the last `requiredSegmentsForConfirmation` segments
    /// into `confirmedSegments`.
    func transcribeCurrentBuffer() async throws {
        guard let whisperKit = whisperKit else { return }

        let currentBuffer = whisperKit.audioProcessor.audioSamples

        // Size and duration of the audio appended since the previous pass.
        let nextBufferSize = currentBuffer.count - lastBufferSize
        let nextBufferSeconds = Float(nextBufferSize) / Float(WhisperKit.sampleRate)

        // Only transcribe once more than 0.1 s of new audio is available.
        guard nextBufferSeconds > 0.1 else {
            await MainActor.run {
                if currentText.isEmpty {
                    currentText = "Waiting for speech..."
                }
            }
            try await Task.sleep(nanoseconds: 100_000_000) // sleep for 100ms for next buffer
            return
        }

        if useVAD {
            let voiceDetected = AudioProcessor.isVoiceDetected(
                in: whisperKit.audioProcessor.relativeEnergy,
                nextBufferInSeconds: nextBufferSeconds,
                silenceThreshold: Float(silenceThreshold)
            )
            // Only transcribe when the new audio contains voice.
            guard voiceDetected else {
                await MainActor.run {
                    if currentText.isEmpty {
                        currentText = "Waiting for speech..."
                    }
                }
                // TODO: Implement silence buffer purging: after ~30 s of silence, purge audio
                // (keeping ~2 s for VAD overlap), reset `lastConfirmedSegmentEndSeconds` and
                // `lastBufferSize`, and confirm any pending segments.

                // Sleep for 100ms and check the next buffer
                try await Task.sleep(nanoseconds: 100_000_000)
                return
            }
        }

        // Remember how much audio has been consumed for the next pass.
        lastBufferSize = currentBuffer.count

        if enableEagerDecoding && transcriptionMode == .streaming {
            // Realtime transcription using word timestamps for segmentation
            let transcription = try await transcribeEagerMode(Array(currentBuffer))
            await MainActor.run {
                currentText = ""
                self.tokensPerSecond = transcription?.timings.tokensPerSecond ?? 0
                self.firstTokenTime = transcription?.timings.firstTokenTime ?? 0
                self.pipelineStart = transcription?.timings.pipelineStart ?? 0
                self.currentLag = transcription?.timings.decodingLoop ?? 0
                self.currentEncodingLoops = Int(transcription?.timings.totalEncodingRuns ?? 0)

                let totalAudio = Double(currentBuffer.count) / Double(WhisperKit.sampleRate)
                self.totalInferenceTime = transcription?.timings.fullPipeline ?? 0
                self.effectiveRealTimeFactor = Double(totalInferenceTime) / totalAudio
                self.effectiveSpeedFactor = totalAudio / Double(totalInferenceTime)
            }
        } else {
            // Realtime transcription using timestamp tokens directly
            let transcription = try await transcribeAudioSamples(Array(currentBuffer))

            await MainActor.run {
                currentText = ""
                guard let segments = transcription?.segments else {
                    return
                }

                self.tokensPerSecond = transcription?.timings.tokensPerSecond ?? 0
                self.firstTokenTime = transcription?.timings.firstTokenTime ?? 0
                self.pipelineStart = transcription?.timings.pipelineStart ?? 0
                self.currentLag = transcription?.timings.decodingLoop ?? 0
                self.currentEncodingLoops += Int(transcription?.timings.totalEncodingRuns ?? 0)

                let totalAudio = Double(currentBuffer.count) / Double(WhisperKit.sampleRate)
                self.totalInferenceTime += transcription?.timings.fullPipeline ?? 0
                self.effectiveRealTimeFactor = Double(totalInferenceTime) / totalAudio
                self.effectiveSpeedFactor = totalAudio / Double(totalInferenceTime)

                guard segments.count > requiredSegmentsForConfirmation else {
                    // Not enough segments yet: everything stays unconfirmed.
                    self.unconfirmedSegments = segments
                    return
                }

                // Confirm all but the trailing `requiredSegmentsForConfirmation` segments.
                let numberOfSegmentsToConfirm = segments.count - requiredSegmentsForConfirmation
                let confirmedSegmentsArray = Array(segments.prefix(numberOfSegmentsToConfirm))
                let remainingSegments = Array(segments.suffix(requiredSegmentsForConfirmation))

                // Only advance when the confirmed range actually moves forward.
                if let lastConfirmedSegment = confirmedSegmentsArray.last,
                   lastConfirmedSegment.end > lastConfirmedSegmentEndSeconds {
                    lastConfirmedSegmentEndSeconds = lastConfirmedSegment.end
                    for segment in confirmedSegmentsArray where !self.confirmedSegments.contains(segment: segment) {
                        self.confirmedSegments.append(segment)
                    }
                }

                self.unconfirmedSegments = remainingSegments
            }
        }
    }

    /// Eager (word-level) streaming transcription.
    ///
    /// Each pass decodes from `lastAgreedSeconds` with the last agreed words as prefix tokens.
    /// Words on which two consecutive passes agree are moved to `confirmedWords`/`confirmedText`;
    /// the rest becomes `hypothesisText`. Returns the merged eager results, or `nil` when no
    /// model is loaded or the model lacks word timestamps.
    func transcribeEagerMode(_ samples: [Float]) async throws -> TranscriptionResult? {
        guard let whisperKit = whisperKit else { return nil }

        guard whisperKit.textDecoder.supportsWordTimestamps else {
            await MainActor.run {
                confirmedText = "Eager mode requires word timestamps, which are not supported by the current model."
            }
            return nil
        }

        let languageCode = Constants.languages[selectedLanguage, default: Constants.defaultLanguageCode]
        let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate

        let options = DecodingOptions(
            verbose: true,
            task: task,
            language: languageCode,
            temperature: Float(temperatureStart),
            temperatureFallbackCount: Int(fallbackCount),
            sampleLength: Int(sampleLength),
            usePrefillPrompt: enablePromptPrefill,
            usePrefillCache: enableCachePrefill,
            skipSpecialTokens: !enableSpecialCharacters,
            withoutTimestamps: !enableTimestamps,
            wordTimestamps: true, // required for eager mode
            firstTokenLogProbThreshold: -1.5 // higher threshold to prevent fallbacks from running too often
        )

        // Mirrors decoding progress into the UI, then applies early stopping.
        let decodingCallback: ((TranscriptionProgress) -> Bool?) = { progress in
            DispatchQueue.main.async {
                let fallbacks = Int(progress.timings.totalDecodingFallbacks)
                if progress.text.count < self.currentText.count && fallbacks != self.currentFallbacks {
                    print("Fallback occured: \(fallbacks)")
                }
                self.currentText = progress.text
                self.currentFallbacks = fallbacks
                self.currentDecodingLoops += 1
            }
            return self.earlyStoppingDecision(for: progress, options: options)
        }

        let sampleRate = Double(WhisperKit.sampleRate)
        Logging.info("[EagerMode] \(lastAgreedSeconds)-\(Double(samples.count) / sampleRate) seconds")

        var streamOptions = options
        streamOptions.clipTimestamps = [lastAgreedSeconds]
        streamOptions.prefixTokens = lastAgreedWords.flatMap { $0.tokens }

        do {
            let transcription: TranscriptionResult? = try await whisperKit.transcribe(
                audioArray: samples,
                decodeOptions: streamOptions,
                callback: decodingCallback
            ).first

            await MainActor.run {
                var skipAppend = false
                if let result = transcription {
                    hypothesisWords = result.allWords.filter { $0.start >= lastAgreedSeconds }

                    if let prevResult = prevResult {
                        prevWords = prevResult.allWords.filter { $0.start >= lastAgreedSeconds }
                        let commonPrefix = findLongestCommonPrefix(prevWords, hypothesisWords)
                        Logging.info("[EagerMode] Prev \"\((prevWords.map { $0.word }).joined())\"")
                        Logging.info("[EagerMode] Next \"\((hypothesisWords.map { $0.word }).joined())\"")
                        Logging.info("[EagerMode] Found common prefix \"\((commonPrefix.map { $0.word }).joined())\"")

                        let confirmationsNeeded = Int(tokenConfirmationsNeeded)
                        if commonPrefix.count >= confirmationsNeeded {
                            // Keep the last few agreed words as the anchor for the next pass
                            // and confirm everything before them.
                            lastAgreedWords = commonPrefix.suffix(confirmationsNeeded)
                            lastAgreedSeconds = lastAgreedWords.first!.start
                            Logging.info("[EagerMode] Found new last agreed word \"\(lastAgreedWords.first!.word)\" at \(lastAgreedSeconds) seconds")

                            confirmedWords.append(contentsOf: commonPrefix.prefix(commonPrefix.count - confirmationsNeeded))
                            let currentWords = confirmedWords.map { $0.word }.joined()
                            Logging.info("[EagerMode] Current: \(lastAgreedSeconds) -> \(Double(samples.count) / sampleRate) \(currentWords)")
                        } else {
                            Logging.info("[EagerMode] Using same last agreed time \(lastAgreedSeconds)")
                            skipAppend = true
                        }
                    }
                    prevResult = result
                }

                if !skipAppend {
                    eagerResults.append(transcription)
                }
            }

            await MainActor.run {
                confirmedText = confirmedWords.map { $0.word }.joined()

                // Accept the final hypothesis because it is the last of the available audio
                let lastHypothesis = lastAgreedWords + findLongestDifferentSuffix(prevWords, hypothesisWords)
                hypothesisText = lastHypothesis.map { $0.word }.joined()
            }
        } catch {
            Logging.error("[EagerMode] Error: \(error)")
            finalizeText()
        }

        return mergeTranscriptionResults(eagerResults, confirmedWords: confirmedWords)
    }

    // MARK: Transcription methods

    /// Returns the confirmed and unconfirmed segments joined as plain text (no timestamps).
    /// Also schedules `finalizeText()` and resets `transcriptionSource` to `.none`.
    public func getFullTranscript() -> String {
        finalizeText()
        let segments = confirmedSegments + unconfirmedSegments
        let transcript = formatSegments(segments, withTimestamps: false)
        transcriptionSource = .none
        return transcript.joined(separator: "\n")
    }
}