in HuggingSnap/Views/VLMEvaluator.swift [133:224]
/// Runs a single generation pass over an optional still image and/or video and
/// publishes the decoded text to `output`.
///
/// The call returns immediately when a generation is already `running`.
/// Otherwise `running` is set for the duration of the call and `output` is
/// cleared before work starts. Partial text is pushed to `output` on the main
/// actor every `displayEveryNTokens` tokens, and generation is stopped once
/// `maxTokens` tokens have been produced. On success `stat` receives the
/// tokens-per-second figure; on any thrown error `output` is replaced with a
/// `"Failed: …"` message.
///
/// - Parameters:
///   - image: Optional still image. It is rotated with `.oriented(.right)`
///     before being passed to the model.
///   - videoURL: Optional video location. When non-nil, the video system/user
///     prompts from `runtimeConfiguration` are used instead of the photo ones.
func generate(image: CIImage?, videoURL: URL?) async {
    // Only one generation may be in flight at a time.
    if running { return }
    running = true
    output = ""

    let rotatedImage = image?.oriented(.right)

    do {
        let container = try await load()
        let generation = try await container.perform { context in
            let isVideo = videoURL != nil

            // Zero-or-one media attachments, depending on what the caller supplied.
            let images: [UserInput.Image] =
                rotatedImage.map { [UserInput.Image.ciImage($0)] } ?? []
            let videos: [UserInput.Video] =
                videoURL.map { [UserInput.Video.url($0)] } ?? []

            let systemPrompt = isVideo
                ? runtimeConfiguration.videoSystemPrompt
                : runtimeConfiguration.photoSystemPrompt
            // A non-empty custom input wins; otherwise fall back to the configured
            // default prompt for the current media kind.
            let userPrompt = await customUserInput.isEmpty
                ? (isVideo ? runtimeConfiguration.videoUserPrompt : runtimeConfiguration.photoUserPrompt)
                : customUserInput

            // Note: the image order is different for smolvlm
            let messages: [Message] = [
                [
                    "role": "system",
                    "content": [
                        ["type": "text", "text": systemPrompt]
                    ],
                ],
                [
                    "role": "user",
                    // One placeholder entry per attachment, followed by the text prompt.
                    "content": []
                        + images.map { _ in ["type": "image"] }
                        + videos.map { _ in ["type": "video"] }
                        + [["type": "text", "text": userPrompt]],
                ],
            ]

            let prepared = try await context.processor.prepare(
                input: UserInput(messages: messages, images: images, videos: videos)
            )
            let sampling = MLXLMCommon.GenerateParameters(
                temperature: runtimeConfiguration.generationParameters.temperature,
                topP: runtimeConfiguration.generationParameters.topP
            )

            return try MLXLMCommon.generate(
                input: prepared,
                parameters: sampling,
                context: context
            ) { tokens in
                // Refresh the visible text periodically so the view updates while
                // generation is still in progress.
                if tokens.count % displayEveryNTokens == 0 {
                    let partial = context.tokenizer.decode(tokens: tokens)
                    Task { @MainActor in
                        self.output = partial
                    }
                }

                guard tokens.count < maxTokens else { return .stop }
                return .more
            }
        }

        // The last streamed update may lag behind the final text (updates only
        // happen every `displayEveryNTokens` tokens), so sync it if it differs.
        if output != generation.output {
            output = generation.output
        }
        stat = " Tokens/second: \(String(format: "%.3f", generation.tokensPerSecond))"
    } catch {
        output = "Failed: \(error)"
    }

    running = false
}