//
// SAM2.swift
// SAM2-Demo
//
// Created by Cyril Zakka on 8/20/24.
//

import SwiftUI
import CoreML
import CoreImage
import CoreImage.CIFilterBuiltins
import Combine
import UniformTypeIdentifiers
import ImageIO
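
/// Wraps the three Core ML components of SAM 2.1 Small (image encoder,
/// prompt encoder, and mask decoder) and publishes their outputs so that
/// SwiftUI views can react as each stage completes.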
@MainActor
class SAM2: ObservableObject {
    @Published var imageEncodings: SAM2_1SmallImageEncoderFLOAT16Output?
    @Published var promptEncodings: SAM2_1SmallPromptEncoderFLOAT16Output?
    @Published private(set) var initializationTime: TimeInterval?
    @Published private(set) var initialized: Bool?

    private var imageEncoderModel: SAM2_1SmallImageEncoderFLOAT16?
    private var promptEncoderModel: SAM2_1SmallPromptEncoderFLOAT16?
    private var maskDecoderModel: SAM2_1SmallMaskDecoderFLOAT16?

    // TODO: examine model inputs instead
    var inputSize: CGSize { CGSize(width: 1024, height: 1024) }
    var width: CGFloat { inputSize.width }
    var height: CGFloat { inputSize.height }

    init() {
        Task {
            await loadModels()
        }
    }
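
    /// Loads the three models on a detached task so the main actor stays
    /// responsive, then publishes the elapsed time and a success flag.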
    private func loadModels() async {
        let startTime = CFAbsoluteTimeGetCurrent()
        do {
            let configuration = MLModelConfiguration()
            configuration.computeUnits = .cpuAndGPU
            let (imageEncoder, promptEncoder, maskDecoder) = try await Task.detached(priority: .userInitiated) {
                let imageEncoder = try SAM2_1SmallImageEncoderFLOAT16(configuration: configuration)
                let promptEncoder = try SAM2_1SmallPromptEncoderFLOAT16(configuration: configuration)
                let maskDecoder = try SAM2_1SmallMaskDecoderFLOAT16(configuration: configuration)
                return (imageEncoder, promptEncoder, maskDecoder)
            }.value
            let elapsed = CFAbsoluteTimeGetCurrent() - startTime
            self.initializationTime = elapsed
            self.initialized = true
            self.imageEncoderModel = imageEncoder
            self.promptEncoderModel = promptEncoder
            self.maskDecoderModel = maskDecoder
            print("Initialized models in \(String(format: "%.4f", elapsed)) seconds")
        } catch {
            print("Failed to initialize models: \(error)")
            self.initializationTime = nil
            self.initialized = false
        }
    }

    // Convenience for use in the CLI
    private var modelLoading: AnyCancellable?
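
    /// Suspends until `initialized` becomes non-nil, resuming exactly once:
    /// with `self` on success, or `SAM2Error.modelNotLoaded` on failure.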
    func ensureModelsAreLoaded() async throws -> SAM2 {
        let _: SAM2 = try await withCheckedThrowingContinuation { continuation in
            // `first()` guarantees the continuation is resumed at most once,
            // even if `initialized` were published again later.
            modelLoading = self.$initialized
                .compactMap { $0 }
                .first()
                .sink { initialized in
                    if initialized {
                        continuation.resume(returning: self)
                    } else {
                        continuation.resume(throwing: SAM2Error.modelNotLoaded)
                    }
                }
        }
        return self
    }

    static func load() async throws -> SAM2 {
        try await SAM2().ensureModelsAreLoaded()
    }

    func getImageEncoding(from pixelBuffer: CVPixelBuffer) async throws {
        guard let model = imageEncoderModel else {
            throw SAM2Error.modelNotLoaded
        }
        let encoding = try model.prediction(image: pixelBuffer)
        self.imageEncodings = encoding
    }

    func getImageEncoding(from url: URL) async throws {
        guard let model = imageEncoderModel else {
            throw SAM2Error.modelNotLoaded
        }
        let inputs = try SAM2_1SmallImageEncoderFLOAT16Input(imageAt: url)
        let encoding = try await model.prediction(input: inputs)
        self.imageEncodings = encoding
    }
    func getPromptEncoding(from allPoints: [SAMPoint], with size: CGSize) async throws {
        guard let model = promptEncoderModel else {
            throw SAM2Error.modelNotLoaded
        }
        let transformedCoords = try transformCoords(allPoints.map { $0.coordinates }, normalize: false, origHW: size)

        // Create MLFeatureProvider with the required input format
        let pointsMultiArray = try MLMultiArray(shape: [1, NSNumber(value: allPoints.count), 2], dataType: .float32)
        let labelsMultiArray = try MLMultiArray(shape: [1, NSNumber(value: allPoints.count)], dataType: .int32)
        for (index, point) in transformedCoords.enumerated() {
            pointsMultiArray[[0, index, 0] as [NSNumber]] = NSNumber(value: Float(point.x))
            pointsMultiArray[[0, index, 1] as [NSNumber]] = NSNumber(value: Float(point.y))
            labelsMultiArray[[0, index] as [NSNumber]] = NSNumber(value: allPoints[index].category.type.rawValue)
        }
        let encoding = try model.prediction(points: pointsMultiArray, labels: labelsMultiArray)
        self.promptEncodings = encoding
    }
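
    /// Picks the candidate mask with the highest score out of the decoder's
    /// output and returns it as a [H, W] array.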
    func bestMask(for output: SAM2_1SmallMaskDecoderFLOAT16Output) -> MLMultiArray {
        if #available(macOS 15.0, *) {
            let scores = output.scoresShapedArray.scalars
            let argmax = scores.firstIndex(of: scores.max() ?? 0) ?? 0
            return MLMultiArray(output.low_res_masksShapedArray[0, argmax])
        } else {
            // Convert scores to float32 for compatibility with macOS < 15,
            // plus ugly loop copy (could do some memcpys)
            let scores = output.scores
            let floatScores = (0..<scores.count).map { scores[$0].floatValue }
            let argmax = floatScores.firstIndex(of: floatScores.max() ?? 0) ?? 0
            let allMasks = output.low_res_masks
            let (h, w) = (allMasks.shape[2], allMasks.shape[3])
            let slice = try! MLMultiArray(shape: [h, w], dataType: allMasks.dataType)
            for i in 0..<h.intValue {
                for j in 0..<w.intValue {
                    let position = [0, argmax, i, j] as [NSNumber]
                    slice[[i as NSNumber, j as NSNumber]] = allMasks[position]
                }
            }
            return slice
        }
    }
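
    /// Runs the mask decoder on the stored image and prompt encodings, keeps
    /// the best low-resolution mask, upsamples it to `original_size`, then
    /// thresholds and tints it for display. Returns nil if either encoding
    /// is missing.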
    func getMask(for original_size: CGSize) async throws -> CIImage? {
        guard let model = maskDecoderModel else {
            throw SAM2Error.modelNotLoaded
        }
        if let image_embedding = self.imageEncodings?.image_embedding,
           let feats0 = self.imageEncodings?.feats_s0,
           let feats1 = self.imageEncodings?.feats_s1,
           let sparse_embedding = self.promptEncodings?.sparse_embeddings,
           let dense_embedding = self.promptEncodings?.dense_embeddings {
            let output = try model.prediction(image_embedding: image_embedding, sparse_embedding: sparse_embedding, dense_embedding: dense_embedding, feats_s0: feats0, feats_s1: feats1)

            // Extract best mask and ignore the others
            let lowFeatureMask = bestMask(for: output)

            // TODO: optimization
            // Preserve the raw value range so the zero crossing survives
            // normalization to [0, 1] for rendering
            var minValue: Double = 9999
            var maxValue: Double = -9999
            for i in 0..<lowFeatureMask.count {
                let v = lowFeatureMask[i].doubleValue
                if v > maxValue { maxValue = v }
                if v < minValue { minValue = v }
            }
            // Where raw value 0 lands in the normalized [0, 1] range
            let threshold = -minValue / (maxValue - minValue)

            // Resize first, then threshold
            if let maskcgImage = lowFeatureMask.cgImage(min: minValue, max: maxValue) {
                let ciImage = CIImage(cgImage: maskcgImage, options: [.colorSpace: NSNull()])
                let resizedImage = try resizeImage(ciImage, to: original_size, applyingThreshold: Float(threshold))
                return resizedImage?.maskedToAlpha()?.samTinted()
            }
        }
        return nil
    }
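
    /// Maps prompt coordinates into the model's 1024x1024 input space.
    /// With `normalize: false` the input is assumed to already be in [0, 1];
    /// with `normalize: true` it is first divided by `origHW`.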
    private func transformCoords(_ coords: [CGPoint], normalize: Bool = false, origHW: CGSize) throws -> [CGPoint] {
        guard normalize else {
            return coords.map { CGPoint(x: $0.x * width, y: $0.y * height) }
        }

        let w = origHW.width
        let h = origHW.height
        return coords.map { coord in
            let normalizedX = coord.x / w
            let normalizedY = coord.y / h
            return CGPoint(x: normalizedX * width, y: normalizedY * height)
        }
    }
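
    /// Scales the mask to the target size, then binarizes it with the given
    /// threshold via `applyingThreshold`, a CIImage helper assumed to be
    /// defined elsewhere in the project.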
    private func resizeImage(_ image: CIImage, to size: CGSize, applyingThreshold threshold: Float = 1) throws -> CIImage? {
        let scale = CGAffineTransform(scaleX: size.width / image.extent.width,
                                      y: size.height / image.extent.height)
        return image.transformed(by: scale).applyingThreshold(threshold)
    }
}

extension CIImage {
    /// This is only appropriate for grayscale mask images (our case). CIColorMatrix can be used more generally.
    func maskedToAlpha() -> CIImage? {
        let filter = CIFilter.maskToAlpha()
        filter.inputImage = self
        return filter.outputImage
    }
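
    /// Tints the alpha mask RGB(30, 144, 255). Each output channel is
    /// weight * channel + alpha - 1, so transparent pixels clamp to 0 and
    /// opaque pixels take the tint color.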
    func samTinted() -> CIImage? {
        let filter = CIFilter.colorMatrix()
        filter.rVector = CIVector(x: 30/255, y: 0, z: 0, w: 1)
        filter.gVector = CIVector(x: 0, y: 144/255, z: 0, w: 1)
        filter.bVector = CIVector(x: 0, y: 0, z: 1, w: 1)
        filter.biasVector = CIVector(x: -1, y: -1, z: -1, w: 0)
        filter.inputImage = self
        return filter.outputImage?.cropped(to: self.extent)
    }
}

enum SAM2Error: Error {
    case modelNotLoaded
    case pixelBufferCreationFailed
    case imageResizingFailed
}
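
/// Writes a CGImage to disk as a PNG via ImageIO; returns whether the write succeeded.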
@discardableResult func writeCGImage(_ image: CGImage, to destinationURL: URL) -> Bool {
    guard let destination = CGImageDestinationCreateWithURL(destinationURL as CFURL, UTType.png.identifier as CFString, 1, nil) else { return false }
    CGImageDestinationAddImage(destination, image, nil)
    return CGImageDestinationFinalize(destination)
}
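
// A minimal end-to-end sketch, assuming click coordinates normalized to
// [0, 1] (as the `normalize: false` call in `getPromptEncoding` implies) and
// a `SAMPoint` matching the rest of the demo; names like `imageURL`,
// `clickPoint`, and `imageSize` are hypothetical caller-side values:
//
//     let sam = try await SAM2.load()
//     try await sam.getImageEncoding(from: imageURL)
//     try await sam.getPromptEncoding(from: [clickPoint], with: imageSize)
//     if let mask = try await sam.getMask(for: imageSize) {
//         // Composite `mask` over the original image for display.
//     }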