Diffusion/Common/State.swift (210 lines of code) (raw):
//
// State.swift
// Diffusion
//
// Created by Pedro Cuenca on 17/1/23.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Combine
import SwiftUI
import StableDiffusion
import CoreML
let DEFAULT_MODEL = ModelInfo.sd3
let DEFAULT_PROMPT = "Labrador in the style of Vermeer"
enum GenerationState {
case startup
case running(StableDiffusionProgress?)
case complete(String, CGImage?, UInt32, TimeInterval?)
case userCanceled
case failed(Error)
}
typealias ComputeUnits = MLComputeUnits
/// Schedulers compatible with StableDiffusionPipeline. This is a local implementation of the StableDiffusionScheduler enum as a String represetation to allow for compliance with NSSecureCoding.
public enum StableDiffusionScheduler: String {
/// Scheduler that uses a pseudo-linear multi-step (PLMS) method
case pndmScheduler
/// Scheduler that uses a second order DPM-Solver++ algorithm
case dpmSolverMultistepScheduler
/// Scheduler for rectified flow based multimodal diffusion transformer models
case discreteFlowScheduler
func asStableDiffusionScheduler() -> StableDiffusion.StableDiffusionScheduler {
switch self {
case .pndmScheduler: return StableDiffusion.StableDiffusionScheduler.pndmScheduler
case .dpmSolverMultistepScheduler: return StableDiffusion.StableDiffusionScheduler.dpmSolverMultistepScheduler
case .discreteFlowScheduler: return StableDiffusion.StableDiffusionScheduler.discreteFlowScheduler
}
}
}
class GenerationContext: ObservableObject {
let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler
@Published var pipeline: Pipeline? = nil {
didSet {
if let pipeline = pipeline {
progressSubscriber = pipeline
.progressPublisher
.receive(on: DispatchQueue.main)
.sink { progress in
guard let progress = progress else { return }
self.updatePreviewIfNeeded(progress)
self.state = .running(progress)
}
}
}
}
@Published var state: GenerationState = .startup
@Published var positivePrompt = Settings.shared.prompt
@Published var negativePrompt = Settings.shared.negativePrompt
// FIXME: Double to support the slider component
@Published var steps: Double = Settings.shared.stepCount
@Published var numImages: Double = 1.0
@Published var seed: UInt32 = Settings.shared.seed
@Published var guidanceScale: Double = Settings.shared.guidanceScale
@Published var previews: Double = runningOnMac ? Settings.shared.previewCount : 0.0
@Published var disableSafety = false
@Published var previewImage: CGImage? = nil
@Published var computeUnits: ComputeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits
private var progressSubscriber: Cancellable?
private func updatePreviewIfNeeded(_ progress: StableDiffusionProgress) {
if previews == 0 || progress.step == 0 {
previewImage = nil
}
if previews > 0, let newImage = progress.currentImages.first, newImage != nil {
previewImage = newImage
}
}
func generate() async throws -> GenerationResult {
guard let pipeline = pipeline else { throw "No pipeline" }
return try pipeline.generate(
prompt: positivePrompt,
negativePrompt: negativePrompt,
scheduler: scheduler,
numInferenceSteps: Int(steps),
seed: seed,
numPreviews: Int(previews),
guidanceScale: Float(guidanceScale),
disableSafety: disableSafety
)
}
func cancelGeneration() {
pipeline?.setCancelled()
}
}
class Settings {
static let shared = Settings()
let defaults = UserDefaults.standard
enum Keys: String {
case model
case safetyCheckerDisclaimer
case computeUnits
case prompt
case negativePrompt
case guidanceScale
case stepCount
case previewCount
case seed
}
private init() {
defaults.register(defaults: [
Keys.model.rawValue: ModelInfo.v2Base.modelId,
Keys.safetyCheckerDisclaimer.rawValue: false,
Keys.computeUnits.rawValue: -1, // Use default
Keys.prompt.rawValue: DEFAULT_PROMPT,
Keys.negativePrompt.rawValue: "",
Keys.guidanceScale.rawValue: 7.5,
Keys.stepCount.rawValue: 25,
Keys.previewCount.rawValue: 5,
Keys.seed.rawValue: 0
])
}
var currentModel: ModelInfo {
set {
defaults.set(newValue.modelId, forKey: Keys.model.rawValue)
}
get {
guard let modelId = defaults.string(forKey: Keys.model.rawValue) else { return DEFAULT_MODEL }
return ModelInfo.from(modelId: modelId) ?? DEFAULT_MODEL
}
}
var prompt: String {
set {
defaults.set(newValue, forKey: Keys.prompt.rawValue)
}
get {
return defaults.string(forKey: Keys.prompt.rawValue) ?? DEFAULT_PROMPT
}
}
var negativePrompt: String {
set {
defaults.set(newValue, forKey: Keys.negativePrompt.rawValue)
}
get {
return defaults.string(forKey: Keys.negativePrompt.rawValue) ?? ""
}
}
var guidanceScale: Double {
set {
defaults.set(newValue, forKey: Keys.guidanceScale.rawValue)
}
get {
return defaults.double(forKey: Keys.guidanceScale.rawValue)
}
}
var stepCount: Double {
set {
defaults.set(newValue, forKey: Keys.stepCount.rawValue)
}
get {
return defaults.double(forKey: Keys.stepCount.rawValue)
}
}
var previewCount: Double {
set {
defaults.set(newValue, forKey: Keys.previewCount.rawValue)
}
get {
return defaults.double(forKey: Keys.previewCount.rawValue)
}
}
var seed: UInt32 {
set {
defaults.set(String(newValue), forKey: Keys.seed.rawValue)
}
get {
if let seedString = defaults.string(forKey: Keys.seed.rawValue), let seedValue = UInt32(seedString) {
return seedValue
}
return 0
}
}
var safetyCheckerDisclaimerShown: Bool {
set {
defaults.set(newValue, forKey: Keys.safetyCheckerDisclaimer.rawValue)
}
get {
return defaults.bool(forKey: Keys.safetyCheckerDisclaimer.rawValue)
}
}
/// Returns the option selected by the user, if overridden
/// `nil` means: guess best
var userSelectedComputeUnits: ComputeUnits? {
set {
// Any value other than the supported ones would cause `get` to return `nil`
defaults.set(newValue?.rawValue ?? -1, forKey: Keys.computeUnits.rawValue)
}
get {
let current = defaults.integer(forKey: Keys.computeUnits.rawValue)
guard current != -1 else { return nil }
return ComputeUnits(rawValue: current)
}
}
public func applicationSupportURL() -> URL {
let fileManager = FileManager.default
guard let appDirectoryURL = fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first else {
// To ensure we don't return an optional - if the user domain application support cannot be accessed use the top level application support directory
return URL.applicationSupportDirectory
}
do {
// Create the application support directory if it doesn't exist
try fileManager.createDirectory(at: appDirectoryURL, withIntermediateDirectories: true, attributes: nil)
return appDirectoryURL
} catch {
print("Error creating application support directory: \(error)")
return fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
}
}
func tempStorageURL() -> URL {
let tmpDir = applicationSupportURL().appendingPathComponent("hf-diffusion-tmp")
// Create directory if it doesn't exist
if !FileManager.default.fileExists(atPath: tmpDir.path) {
do {
try FileManager.default.createDirectory(at: tmpDir, withIntermediateDirectories: true, attributes: nil)
} catch {
print("Failed to create temporary directory: \(error)")
return FileManager.default.temporaryDirectory
}
}
return tmpDir
}
}