Diffusion/Common/Pipeline/Pipeline.swift (107 lines of code) (raw):
//
// Pipeline.swift
// Diffusion
//
// Created by Pedro Cuenca on December 2022.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Foundation
import CoreML
import Combine
import StableDiffusion
/// Progress report for a single generation step, optionally carrying preview images.
struct StableDiffusionProgress {
    /// Raw progress value handed over by the underlying pipeline.
    var progress: StableDiffusionPipeline.Progress

    /// Index of the step this report belongs to.
    var step: Int { progress.step }

    /// Total number of steps reported by the pipeline.
    var stepCount: Int { progress.stepCount }

    /// Preview images for this step, or `[nil]` when no preview was requested for it.
    var currentImages: [CGImage?]

    /// Builds a progress report.
    /// - Parameters:
    ///   - progress: The pipeline's progress value for the current step.
    ///   - previewIndices: One flag per step; `true` marks steps that should expose a preview.
    init(progress: StableDiffusionPipeline.Progress, previewIndices: [Bool]) {
        self.progress = progress

        // `progress.currentImages` is a computed property, so it is only read
        // for steps that were explicitly flagged for preview.
        let currentStep = progress.step
        let wantsPreview = currentStep < previewIndices.count && previewIndices[currentStep]
        currentImages = wantsPreview ? progress.currentImages : [nil]
    }
}
/// Outcome of a single image generation request, as produced by `Pipeline.generate`.
struct GenerationResult {
/// The generated image; `nil` when no image came back (the generator notes this means the safety checker triggered).
var image: CGImage?
/// The seed that was actually used for this generation.
var lastSeed: UInt32
/// Elapsed time of the generation call, in seconds.
var interval: TimeInterval?
/// Whether cancellation had been requested by the time the generation call returned.
var userCanceled: Bool
/// Iterations per second, derived by the generator from the reciprocal of the median sample time.
var itsPerSecond: Double?
}
/// Wraps a Stable Diffusion pipeline and adds seed selection, progress publishing,
/// preview scheduling, timing and cancellation on top of it.
class Pipeline {
    /// The underlying pipeline that actually generates images.
    let pipeline: StableDiffusionPipelineProtocol

    /// Inclusive upper bound used when a random seed has to be drawn.
    let maxSeed: UInt32

    /// `true` when the wrapped pipeline is a `StableDiffusionXLPipeline`.
    var isXL: Bool {
        if #available(macOS 14.0, iOS 17.0, *) {
            return pipeline is StableDiffusionXLPipeline
        }
        return false
    }

    /// `true` when the wrapped pipeline is a `StableDiffusion3Pipeline`.
    var isSD3: Bool {
        if #available(macOS 14.0, iOS 17.0, *) {
            return pipeline is StableDiffusion3Pipeline
        }
        return false
    }

    /// Most recent progress report; every assignment is forwarded to `progressPublisher`.
    var progress: StableDiffusionProgress? = nil {
        didSet {
            progressPublisher.value = progress
        }
    }

    /// Publishes progress updates. Created lazily and seeded with the current `progress`.
    lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)

    /// Set by `setCancelled()`; read by the progress handler after every step.
    private var canceled = false

    /// Creates a wrapper around `pipeline`.
    /// - Parameters:
    ///   - pipeline: The pipeline used for generation.
    ///   - maxSeed: Inclusive upper bound for randomly drawn seeds.
    init(_ pipeline: StableDiffusionPipelineProtocol, maxSeed: UInt32 = UInt32.max) {
        self.pipeline = pipeline
        self.maxSeed = maxSeed
    }

    /// Generates an image for `prompt`, publishing progress while the pipeline runs.
    /// - Parameters:
    ///   - prompt: Text prompt for the image.
    ///   - negativePrompt: Negative prompt passed to the pipeline configuration.
    ///   - scheduler: Scheduler selection, converted with `asStableDiffusionScheduler()`.
    ///   - stepCount: Number of inference steps.
    ///   - seed: Seed to use; `0` means "draw a random seed in `1...maxSeed`".
    ///   - previewCount: Number of previews to distribute across the inference steps.
    ///   - guidanceScale: Guidance scale passed to the pipeline configuration.
    ///   - disableSafety: Passed through to the pipeline configuration.
    /// - Returns: The first non-nil image (if any) together with the seed used, the elapsed
    ///   time, the cancellation flag and the measured iterations per second.
    /// - Throws: Whatever `pipeline.generateImages(configuration:)` throws.
    func generate(
        prompt: String,
        negativePrompt: String = "",
        scheduler: StableDiffusionScheduler,
        numInferenceSteps stepCount: Int = 50,
        seed: UInt32 = 0,
        numPreviews previewCount: Int = 5,
        guidanceScale: Float = 7.5,
        disableSafety: Bool = false
    ) throws -> GenerationResult {
        let beginDate = Date()
        canceled = false

        // `1...maxSeed` is an invalid range (runtime trap) when maxSeed is 0, so the
        // upper bound is clamped to at least 1.
        let theSeed = seed > 0 ? seed : UInt32.random(in: 1...max(1, maxSeed))

        let sampleTimer = SampleTimer()
        sampleTimer.start()

        var config = StableDiffusionPipeline.Configuration(prompt: prompt)
        config.negativePrompt = negativePrompt
        config.stepCount = stepCount
        config.seed = theSeed
        config.guidanceScale = guidanceScale
        config.disableSafety = disableSafety
        config.schedulerType = scheduler.asStableDiffusionScheduler()
        config.useDenoisedIntermediates = true

        // Overrides applied when the wrapped pipeline is an XL pipeline.
        if isXL {
            config.encoderScaleFactor = 0.13025
            config.decoderScaleFactor = 0.13025
            config.schedulerTimestepSpacing = .karras
        }

        // Overrides applied when the wrapped pipeline is an SD3 pipeline.
        if isSD3 {
            config.encoderScaleFactor = 1.5305
            config.decoderScaleFactor = 1.5305
            config.decoderShiftFactor = 0.0609
            config.schedulerTimestepShift = 3.0
        }

        // Evenly distribute previews based on inference steps.
        // The local is deliberately not named `previewIndices`, so it does not shadow
        // the function being called inside its own initializer.
        let previewFlags = previewIndices(stepCount, previewCount)

        let images = try pipeline.generateImages(configuration: config) { progress in
            // Close the timing sample for the step that just finished.
            sampleTimer.stop()
            handleProgress(StableDiffusionProgress(progress: progress,
                                                   previewIndices: previewFlags),
                           sampleTimer: sampleTimer)
            // Only open a new timing sample when `step` has not reached `stepCount`.
            if progress.stepCount != progress.step {
                sampleTimer.start()
            }
            // Returning `false` reports the cancellation request back to the pipeline.
            return !canceled
        }

        let interval = Date().timeIntervalSince(beginDate)
        print("Got images: \(images) in \(interval)")

        // Unwrap the 1 image we asked for, nil means safety checker triggered
        let image = images.compactMap({ $0 }).first

        // A non-positive (or NaN) median would turn the reciprocal into inf/NaN;
        // report "unknown" instead of a meaningless rate.
        let medianSampleTime = sampleTimer.median
        let itsPerSecond: Double? = medianSampleTime > 0 ? 1.0 / medianSampleTime : nil

        return GenerationResult(image: image,
                                lastSeed: theSeed,
                                interval: interval,
                                userCanceled: canceled,
                                itsPerSecond: itsPerSecond)
    }

    /// Records `progress`, which in turn publishes it through `progressPublisher`.
    /// - Parameters:
    ///   - progress: The progress report for the step that just completed.
    ///   - sampleTimer: The timer measuring step durations; not used by this implementation.
    func handleProgress(_ progress: StableDiffusionProgress, sampleTimer: SampleTimer) {
        self.progress = progress
    }

    /// Requests cancellation; the progress handler returns `false` on its next invocation.
    func setCancelled() {
        canceled = true
    }
}