SAM2-Demo/Views/ImageView.swift (126 lines of code) (raw):
//
// ImageView.swift
// SAM2-Demo
//
// Created by Cyril Zakka on 9/8/24.
//
import SwiftUI
/// Displays the image being segmented and turns user input into SAM2 prompts.
///
/// - With the point tool, a tap appends a `SAMPoint`.
/// - With the bounding-box tool, a drag appends a `SAMBox`.
/// - Coordinates are stored normalized to the displayed `imageSize` (see `CGPoint.fromSize`).
/// - When the number of points or boxes changes, `performForwardPass()` sends the combined
///   prompt sequence to `sam2`. The returned mask is published as `currentSegmentation`.
struct ImageView: View {
    let image: NSImage
    @Binding var currentScale: CGFloat
    @Binding var selectedTool: SAMTool?
    @Binding var selectedCategory: SAMCategory?
    @Binding var selectedPoints: [SAMPoint]
    @Binding var boundingBoxes: [SAMBox]
    @Binding var currentBox: SAMBox?
    @Binding var segmentationImages: [SAMSegmentation]
    @Binding var currentSegmentation: SAMSegmentation?
    @Binding var imageSize: CGSize
    @Binding var originalSize: NSSize?
    @State var animationPoint: CGPoint = .zero
    @ObservedObject var sam2: SAM2
    @State private var error: Error?
    /// True while this view has a cursor pushed onto AppKit's cursor stack that still
    /// needs a matching `NSCursor.pop()`. Keeps push/pop balanced across hover events.
    @State private var isCursorPushed = false

    /// All prompt points sent to the model: every bounding box's points first,
    /// followed by the individually placed points.
    var pointSequence: [SAMPoint] {
        boundingBoxes.flatMap { $0.points } + selectedPoints
    }

    var body: some View {
        Image(nsImage: image)
            .resizable()
            .aspectRatio(contentMode: .fit)
            .scaleEffect(currentScale)
            .onTapGesture(coordinateSpace: .local) { handleTap(at: $0) }
            .gesture(boundingBoxGesture)
            .onHover { changeCursorAppearance(is: $0) }
            .background(GeometryReader { geometry in
                // Report the laid-out image size so taps and drags can be normalized against it.
                Color.clear.preference(key: SizePreferenceKey.self, value: geometry.size)
            })
            .onPreferenceChange(SizePreferenceKey.self) { imageSize = $0 }
            .onChange(of: selectedPoints.count, {
                if !selectedPoints.isEmpty {
                    performForwardPass()
                }
            })
            .onChange(of: boundingBoxes.count, {
                if !boundingBoxes.isEmpty {
                    performForwardPass()
                }
            })
            .overlay {
                PointsOverlay(selectedPoints: $selectedPoints, selectedTool: $selectedTool, imageSize: imageSize)
                BoundingBoxesOverlay(boundingBoxes: boundingBoxes, currentBox: currentBox, imageSize: imageSize)
                // Earlier entries get a higher zIndex, so they are drawn above later ones.
                ForEach(Array(segmentationImages.enumerated()), id: \.element.id) { index, segmentation in
                    SegmentationOverlay(segmentationImage: $segmentationImages[index], imageSize: imageSize, shouldAnimate: false)
                        .zIndex(Double(segmentationImages.count - index))
                }
                // The in-progress mask sits above all stored segmentations and animates from the last prompt.
                if let currentSegmentation = currentSegmentation {
                    SegmentationOverlay(segmentationImage: .constant(currentSegmentation), imageSize: imageSize, origin: animationPoint, shouldAnimate: true)
                        .zIndex(Double(segmentationImages.count + 1))
                }
            }
    }

    /// Shows a tool-specific cursor while the pointer is over the image.
    ///
    /// Pushes at most one cursor per hover session. Pops only a cursor that this view
    /// pushed itself, so a hover-exit with no tool selected never pops someone else's cursor.
    /// - Parameter inside: `true` when the pointer entered the view, `false` when it left.
    private func changeCursorAppearance(is inside: Bool) {
        guard inside else {
            if isCursorPushed {
                NSCursor.pop()
                isCursorPushed = false
            }
            return
        }
        // Guard against repeated "entered" events stacking multiple cursors.
        guard !isCursorPushed else { return }

        let cursor: NSCursor
        if selectedTool == pointTool {
            cursor = .pointingHand
        } else if selectedTool == boundingBoxTool {
            cursor = .crosshair
        } else {
            // No cursor-changing tool is active; leave the system cursor alone.
            return
        }
        cursor.push()
        isCursorPushed = true
    }

    /// Drag gesture that builds `currentBox` while dragging (bounding-box tool only).
    /// On release it commits the box to `boundingBoxes`.
    ///
    /// No box is started when no category is selected.
    private var boundingBoxGesture: some Gesture {
        DragGesture(minimumDistance: 0)
            .onChanged { value in
                guard selectedTool == boundingBoxTool else { return }
                let endPoint = value.location.fromSize(imageSize)
                if currentBox == nil {
                    // A box needs a category; without one, don't start a box (avoids a force-unwrap crash).
                    guard let category = selectedCategory else { return }
                    currentBox = SAMBox(startPoint: value.startLocation.fromSize(imageSize), endPoint: endPoint, category: category)
                } else {
                    currentBox?.endPoint = endPoint
                }
            }
            .onEnded { _ in
                guard selectedTool == boundingBoxTool else { return }
                if let box = currentBox {
                    boundingBoxes.append(box)
                    // Start the mask reveal animation from the box center, in view coordinates.
                    animationPoint = box.midpoint.toSize(imageSize)
                    currentBox = nil
                }
            }
    }

    /// With the point tool active, places a prompt point at the tapped location.
    /// It also records that location as the origin of the mask animation.
    /// - Parameter location: Tap location in the view's local coordinate space.
    private func handleTap(at location: CGPoint) {
        guard selectedTool == pointTool else { return }
        placePoint(at: location)
        animationPoint = location
    }

    /// Appends a `SAMPoint` for `coordinates`, normalized to `imageSize` and tagged with the selected category.
    ///
    /// Does nothing when no category is selected, instead of crashing on a force-unwrap.
    /// - Parameter coordinates: Point in the view's local coordinate space.
    private func placePoint(at coordinates: CGPoint) {
        guard let category = selectedCategory else { return }
        selectedPoints.append(SAMPoint(coordinates: coordinates.fromSize(imageSize), category: category))
    }

    /// Runs the prompt encoder and mask decoder on the current `pointSequence`.
    ///
    /// On success, the mask is stored in `currentSegmentation`, tinted with the candidate color
    /// furthest from the colors already in use. Failures are stored in `error`.
    /// All view state is written on the main actor.
    private func performForwardPass() {
        Task {
            do {
                try await sam2.getPromptEncoding(from: pointSequence, with: imageSize)
                guard let mask = try await sam2.getMask(for: originalSize ?? .zero) else { return }
                await MainActor.run {
                    let usedColors = segmentationImages.map { $0.tintColor }
                    let tint = furthestColor(from: usedColors, among: SAMSegmentation.candidateColors)
                    let title = "Untitled \(segmentationImages.count + 1)"
                    currentSegmentation = SAMSegmentation(image: mask, tintColor: tint, title: title)
                }
            } catch {
                // @State must be mutated on the main actor.
                await MainActor.run { self.error = error }
            }
        }
    }
}
// Xcode canvas preview: renders the app's root ContentView rather than ImageView on its own.
#Preview { ContentView() }
extension CGPoint {
    /// Expresses this point relative to `size` by dividing each axis by the matching dimension.
    /// The result is in the 0...1 range when the point lies inside the size.
    ///
    /// No zero-size guard: a zero `width` or `height` yields non-finite components.
    func fromSize(_ size: CGSize) -> CGPoint {
        let relativeX = x / size.width
        let relativeY = y / size.height
        return CGPoint(x: relativeX, y: relativeY)
    }

    /// Inverse of `fromSize(_:)`: scales a relative point back up to absolute coordinates in `size`.
    func toSize(_ size: CGSize) -> CGPoint {
        let absoluteX = size.width * x
        let absoluteY = size.height * y
        return CGPoint(x: absoluteX, y: absoluteY)
    }
}