// SAM2-Demo/ContentView.swift (313 lines of code) (raw):
import SwiftUI
import PhotosUI
import UniformTypeIdentifiers
import CoreML
import os
// TODO: Add reset, bounding box, and eraser
/// File-wide logger used by `ContentView` to report import failures.
let logger = Logger(subsystem: "com.cyrilzakka.SAM2-Demo.ContentView", category: "ContentView")
/// Draws one small filled circle for every entry in `selectedPoints`.
///
/// Each circle is tinted with its point's category color and positioned at the
/// point's coordinates after they are mapped through `toSize(imageSize)`.
struct PointsOverlay: View {
    @Binding var selectedPoints: [SAMPoint]
    @Binding var selectedTool: SAMTool?
    let imageSize: CGSize

    var body: some View {
        // Diameter of each marker.
        let dotDiameter: CGFloat = 10
        ForEach(selectedPoints, id: \.self) { marker in
            let center = marker.coordinates.toSize(imageSize)
            Circle()
                .frame(width: dotDiameter, height: dotDiameter)
                .foregroundStyle(marker.category.color)
                .position(center)
        }
    }
}
/// Renders every box in `boundingBoxes`, followed by `currentBox` when it is non-nil.
struct BoundingBoxesOverlay: View {
    let boundingBoxes: [SAMBox]
    let currentBox: SAMBox?
    let imageSize: CGSize

    var body: some View {
        ForEach(boundingBoxes) { storedBox in
            BoundingBoxPath(box: storedBox, imageSize: imageSize)
        }
        // Drawn after the stored boxes so it is rendered last.
        if let activeBox = currentBox {
            BoundingBoxPath(box: activeBox, imageSize: imageSize)
        }
    }
}
/// Strokes the dashed rectangular outline of a single `SAMBox`.
///
/// The rectangle is spanned by the box's `startPoint` and `endPoint`; every corner
/// is mapped through `toSize(imageSize)` before being added to the path.
struct BoundingBoxPath: View {
    let box: SAMBox
    let imageSize: CGSize

    var body: some View {
        let dashedOutline = StrokeStyle(lineWidth: 2, dash: [5, 5])
        Path { outline in
            // Corners in drawing order: start, (end.x, start.y), end, (start.x, end.y).
            let corners = [
                box.startPoint.toSize(imageSize),
                CGPoint(x: box.endPoint.x, y: box.startPoint.y).toSize(imageSize),
                box.endPoint.toSize(imageSize),
                CGPoint(x: box.startPoint.x, y: box.endPoint.y).toSize(imageSize),
            ]
            outline.move(to: corners[0])
            for corner in corners.dropFirst() {
                outline.addLine(to: corner)
            }
            outline.closeSubpath()
        }
        .stroke(box.category.color, style: dashedOutline)
    }
}
/// Displays a segmentation mask as a semi-transparent image sized to `imageSize`.
///
/// The overlay ignores hit testing, is fully transparent while the segmentation is
/// hidden, and bumps `counter` once on appearance (when `shouldAnimate` is set) to
/// trigger the attached `RippleEffect`.
struct SegmentationOverlay: View {
    @Binding var segmentationImage: SAMSegmentation
    let imageSize: CGSize
    @State var counter: Int = 0
    // NOTE(review): `origin` is never read in this view — confirm whether it was
    // meant to be the ripple origin.
    var origin: CGPoint = .zero
    var shouldAnimate: Bool = false

    var body: some View {
        let mask = segmentationImage.cgImage
        // NOTE(review): the ripple center is derived from the mask's pixel dimensions,
        // not from `imageSize` — confirm which coordinate space `RippleEffect` expects.
        let rippleCenter = CGPoint(x: mask.width / 2, y: mask.height / 2)
        let overlayOpacity: Double = segmentationImage.isHidden ? 0 : 0.6

        Image(nsImage: NSImage(cgImage: mask, size: imageSize))
            .resizable()
            .scaledToFit()
            .allowsHitTesting(false)
            .frame(width: imageSize.width, height: imageSize.height)
            .opacity(overlayOpacity)
            .modifier(RippleEffect(at: rippleCenter, trigger: counter))
            .onAppear {
                guard shouldAnimate else { return }
                counter += 1
            }
    }
}
/// Root view of the app.
///
/// Layout:
/// - sidebar: the list of stored masks (`LayerListView`) and a "New Mask" button;
/// - detail: a zoomable image canvas (`ImageView`) with a `SubToolbar` on top;
/// - inspector: a `MaskEditor` for the selected masks;
/// - toolbar: tool and category pickers plus an import menu.
///
/// The view also owns the import (Photos / Files) and export (PNG per mask) flows and
/// asks `sam2` for a fresh image encoding whenever the displayed image changes.
struct ContentView: View {
    // ML Models
    @StateObject private var sam2 = SAM2()
    @State private var currentSegmentation: SAMSegmentation?
    @State private var segmentationImages: [SAMSegmentation] = []
    @State private var imageSize: CGSize = .zero

    // File importer
    @State private var imageURL: URL?
    @State private var isImportingFromFiles: Bool = false
    @State private var displayImage: NSImage?

    // Mask exporter
    @State private var exportURL: URL?
    @State private var exportMaskToPNG: Bool = false
    @State private var showInspector: Bool = true
    @State private var selectedSegmentations = Set<SAMSegmentation.ID>()

    // Photos Picker
    @State private var isImportingFromPhotos: Bool = false
    @State private var selectedItem: PhotosPickerItem?
    @State private var error: Error?

    // ML Model Properties
    var tools: [SAMTool] = [pointTool, boundingBoxTool]
    var categories: [SAMCategory] = [.foreground, .background]
    @State private var selectedTool: SAMTool?
    @State private var selectedCategory: SAMCategory?
    @State private var selectedPoints: [SAMPoint] = []
    @State private var boundingBoxes: [SAMBox] = []
    @State private var currentBox: SAMBox?
    @State private var originalSize: NSSize?
    @State private var currentScale: CGFloat = 1.0
    @State private var visibleRect: CGRect = .zero

    var body: some View {
        NavigationSplitView(sidebar: {
            VStack {
                LayerListView(segmentationImages: $segmentationImages, selectedSegmentations: $selectedSegmentations, currentSegmentation: $currentSegmentation)
                Spacer()
                // Commits the in-progress mask to the layer list and clears all prompts.
                Button(action: {
                    if let currentSegmentation {
                        segmentationImages.append(currentSegmentation)
                        reset()
                    }
                }, label: {
                    Text("New Mask")
                }).padding()
            }
        }, detail: {
            ZStack {
                ZoomableScrollView(visibleRect: $visibleRect) {
                    if let image = displayImage {
                        ImageView(image: image, currentScale: $currentScale, selectedTool: $selectedTool, selectedCategory: $selectedCategory, selectedPoints: $selectedPoints, boundingBoxes: $boundingBoxes, currentBox: $currentBox, segmentationImages: $segmentationImages, currentSegmentation: $currentSegmentation, imageSize: $imageSize, originalSize: $originalSize, sam2: sam2)
                    } else {
                        ContentUnavailableView("No Image Loaded", systemImage: "photo.fill.on.rectangle.fill", description: Text("Please import a photo to get started."))
                    }
                }
                VStack(spacing: 0) {
                    SubToolbar(selectedPoints: $selectedPoints, boundingBoxes: $boundingBoxes, segmentationImages: $segmentationImages, currentSegmentation: $currentSegmentation)
                    Spacer()
                }
            }
        })
        .inspector(isPresented: $showInspector, content: {
            if selectedSegmentations.isEmpty {
                ContentUnavailableView(label: {
                    Label(title: {
                        Text("No Mask Selected")
                            .font(.subheadline)
                    }, icon: {})
                })
                .inspectorColumnWidth(min: 200, ideal: 200, max: 200)
            } else {
                MaskEditor(exportMaskToPNG: $exportMaskToPNG, segmentationImages: $segmentationImages, selectedSegmentations: $selectedSegmentations, currentSegmentation: $currentSegmentation)
                    .inspectorColumnWidth(min: 200, ideal: 200, max: 200)
                    .toolbar {
                        Spacer()
                        Button {
                            showInspector.toggle()
                        } label: {
                            Label("Toggle Inspector", systemImage: "sidebar.trailing")
                        }
                    }
            }
        })
        .toolbar {
            // Tools
            ToolbarItemGroup(placement: .principal) {
                Picker(selection: $selectedTool, content: {
                    ForEach(tools, id: \.self) { tool in
                        Label(tool.name, systemImage: tool.iconName)
                            .tag(tool)
                            .labelStyle(.titleAndIcon)
                    }
                }, label: {
                    Label("Tools", systemImage: "pencil.and.ruler")
                })
                .pickerStyle(.menu)

                Picker(selection: $selectedCategory, content: {
                    ForEach(categories, id: \.self) { cat in
                        Label(cat.name, systemImage: cat.iconName)
                            .tag(cat)
                            .labelStyle(.titleAndIcon)
                    }
                }, label: {
                    // This picker selects the prompt category, not the tool.
                    Label("Category", systemImage: "pencil.and.ruler")
                })
                .pickerStyle(.menu)
            }

            // Import
            ToolbarItemGroup {
                Menu {
                    Button(action: {
                        isImportingFromPhotos = true
                    }, label: {
                        Label("From Photos", systemImage: "photo.on.rectangle.angled.fill")
                    })
                    Button(action: {
                        isImportingFromFiles = true
                    }, label: {
                        Label("From Files", systemImage: "folder.fill")
                    })
                } label: {
                    Label("Import", systemImage: "photo.badge.plus")
                }
            }
        }
        .onAppear {
            // Default to the first tool/category; `first` keeps this safe even if a
            // list is empty.
            if selectedTool == nil {
                selectedTool = tools.first
            }
            if selectedCategory == nil {
                selectedCategory = categories.first
            }
        }
        // MARK: - Image encoding
        .onChange(of: displayImage) {
            // A new image invalidates every stored mask and all pending prompts.
            segmentationImages = []
            reset()
            Task {
                // NOTE(review): 1024x1024 is presumably the encoder's input size — confirm.
                if let displayImage, let pixelBuffer = displayImage.pixelBuffer(width: 1024, height: 1024) {
                    originalSize = displayImage.size
                    do {
                        try await sam2.getImageEncoding(from: pixelBuffer)
                    } catch {
                        self.error = error
                    }
                }
            }
        }
        // MARK: - Photos Importer
        .photosPicker(isPresented: $isImportingFromPhotos, selection: $selectedItem, matching: .any(of: [.images, .screenshots, .livePhotos]))
        .onChange(of: selectedItem) {
            // The file importer resets `selectedItem` to nil; that is not a failed
            // Photos import, so there is nothing to load and nothing to report.
            guard let pickedItem = selectedItem else { return }
            Task { @MainActor in
                do {
                    guard let loadedData = try await pickedItem.loadTransferable(type: Data.self),
                          let image = NSImage(data: loadedData) else {
                        // Keep the currently displayed image when the data cannot be
                        // decoded, matching `loadImage(from:)`.
                        logger.error("Error loading image from Photos.")
                        return
                    }
                    selectedPoints.removeAll()
                    displayImage = image
                } catch {
                    logger.error("Error loading image from Photos: \(error.localizedDescription)")
                    self.error = error
                }
            }
        }
        // MARK: - File Importer
        .fileImporter(isPresented: $isImportingFromFiles,
                      allowedContentTypes: [.image]) { result in
            switch result {
            case .success(let file):
                selectedItem = nil
                selectedPoints.removeAll()
                imageURL = file
                loadImage(from: file)
            case .failure(let error):
                logger.error("File import error: \(error.localizedDescription)")
                self.error = error
            }
        }
        // MARK: - File exporter
        .fileExporter(
            isPresented: $exportMaskToPNG,
            document: DirectoryDocument(initialContentType: .folder),
            contentType: .folder,
            defaultFilename: "Segmentations"
        ) { result in
            if case .success(let url) = result {
                exportURL = url
                // Export the selected stored masks plus the in-progress mask, if any.
                var selectedToExport = segmentationImages.filter { segmentation in
                    selectedSegmentations.contains(segmentation.id)
                }
                if let currentSegmentation {
                    selectedToExport.append(currentSegmentation)
                }
                exportSegmentations(selectedToExport, to: url)
            }
        }
    }

    // MARK: - Private Methods

    /// Reads the image at `url` and schedules it to become the displayed image.
    ///
    /// The URL is accessed as a security-scoped resource for the duration of the read.
    /// Failures are logged; a read error is additionally stored in `error`.
    ///
    /// - Parameter url: Location of the image file chosen in the file importer.
    private func loadImage(from url: URL) {
        guard url.startAccessingSecurityScopedResource() else {
            logger.error("Failed to access the file. Security-scoped resource access denied.")
            return
        }
        defer { url.stopAccessingSecurityScopedResource() }

        do {
            let imageData = try Data(contentsOf: url)
            guard let image = NSImage(data: imageData) else {
                logger.error("Failed to create NSImage from file data")
                return
            }
            DispatchQueue.main.async {
                displayImage = image
            }
        } catch {
            logger.error("Error loading image data: \(error.localizedDescription)")
            self.error = error
        }
    }

    /// Writes each segmentation mask as `segmentation_<n>.png` (1-based) into `directory`.
    ///
    /// The directory is created if needed. A mask that cannot be written is logged and
    /// skipped; the remaining masks are still exported.
    ///
    /// - Parameters:
    ///   - segmentations: Masks to export, in output order.
    ///   - directory: Destination folder.
    func exportSegmentations(_ segmentations: [SAMSegmentation], to directory: URL) {
        do {
            try FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true, attributes: nil)
        } catch {
            logger.error("Error creating directory: \(error.localizedDescription)")
            return
        }

        for (index, segmentation) in segmentations.enumerated() {
            let number = index + 1
            let fileURL = directory.appendingPathComponent("segmentation_\(number).png")
            guard let destination = CGImageDestinationCreateWithURL(fileURL as CFURL, UTType.png.identifier as CFString, 1, nil) else {
                logger.error("Failed to create image destination for segmentation \(number)")
                continue
            }
            CGImageDestinationAddImage(destination, segmentation.cgImage, nil)
            if CGImageDestinationFinalize(destination) {
                logger.info("Saved segmentation \(number) to \(fileURL.path)")
            } else {
                logger.error("Failed to save segmentation \(number)")
            }
        }
    }

    /// Clears all prompts (points, boxes, in-progress box) and the in-progress mask.
    private func reset() {
        selectedPoints = []
        boundingBoxes = []
        currentBox = nil
        currentSegmentation = nil
    }
}
/// Preference key carrying a `CGSize`.
///
/// Starts at `.zero`; when several values are reported, each new value replaces the
/// previous one.
struct SizePreferenceKey: PreferenceKey {
    typealias Value = CGSize

    static var defaultValue: Value = .zero

    static func reduce(value current: inout Value, nextValue next: () -> Value) {
        current = next()
    }
}
// Preview that instantiates `ContentView` with its default state.
#Preview {
ContentView()
}