sam2-cli/MainCommand.swift
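// Example invocation (paths and coordinates below are placeholders, not part
// of the project). Option names follow from the `.shortAndLong` declarations
// in this file, so the long and short spellings are interchangeable:
//
//     sam2-cli --input cat.jpg --points 512,512 768,300 --types 1 0 \
//              --output overlay.png --mask mask.png
//
// Short flags: -i, -p, -t, -o, and -k for --mask.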

import ArgumentParser
import CoreImage
import CoreML
import ImageIO
import UniformTypeIdentifiers
import Combine

// Shared Core Image context. Passing NSNull() as the output color space
// disables color management so pixel values pass through unchanged.
let context = CIContext(options: [.outputColorSpace: NSNull()])

// Maps the integer labels accepted on the command line to SAM prompt categories.
enum PointType: Int, ExpressibleByArgument {
    case background = 0
    case foreground = 1

    var asCategory: SAMCategory {
        switch self {
        case .background: return SAMCategory.background
        case .foreground: return SAMCategory.foreground
        }
    }
}

@main
struct MainCommand: AsyncParsableCommand {
    static let configuration = CommandConfiguration(
        commandName: "sam2-cli",
        abstract: "Perform segmentation using the SAM v2 model."
    )

    @Option(name: .shortAndLong, help: "The input image file.")
    var input: String

    // TODO: multiple points
    @Option(name: .shortAndLong, parsing: .upToNextOption, help: "List of input points in 'x,y' format. Coordinates are relative to the input image size. Separate multiple entries with spaces, but don't use spaces within a pair.")
    var points: [CGPoint]

    @Option(name: .shortAndLong, parsing: .upToNextOption, help: "Point types corresponding to the input points. Provide as many types as points: 0 for background, 1 for foreground.")
    var types: [PointType]

    @Option(name: .shortAndLong, help: "The output PNG image file, showing the segmentation map overlaid on top of the original image.")
    var output: String

    @Option(name: [.long, .customShort("k")], help: "The output file name for the segmentation mask.")
    var mask: String? = nil

    @MainActor
    mutating func run() async throws {
        // TODO: specify directory with loadable .mlpackages instead
        let sam = try await SAM2.load()
        print("Models loaded in: \(String(describing: sam.initializationTime))")
        let targetSize = sam.inputSize

        // Load the input image
        guard let inputImage = CIImage(contentsOf: URL(filePath: input), options: [.colorSpace: NSNull()]) else {
            print("Failed to load image.")
            throw ExitCode(EXIT_FAILURE)
        }
        print("Original image size \(inputImage.extent)")

        // Resize the image to match the model's expected input
        let resizedImage = inputImage.resized(to: targetSize)

        // Convert to a pixel buffer
        guard let pixelBuffer = context.render(resizedImage, pixelFormat: kCVPixelFormatType_32ARGB) else {
            print("Failed to create pixel buffer for input image.")
            throw ExitCode(EXIT_FAILURE)
        }

        // Execute the model
        let clock = ContinuousClock()
        let start = clock.now
        try await sam.getImageEncoding(from: pixelBuffer)
        let duration = clock.now - start
        print("Image encoding took \(duration.formatted(.units(allowed: [.seconds, .milliseconds])))")

        let startMask = clock.now
        let pointSequence = zip(points, types).map { point, type in
            SAMPoint(coordinates: point, category: type.asCategory)
        }
        try await sam.getPromptEncoding(from: pointSequence, with: inputImage.extent.size)
        guard let maskImage = try await sam.getMask(for: inputImage.extent.size) else {
            throw ExitCode(EXIT_FAILURE)
        }
        let maskDuration = clock.now - startMask
        print("Prompt encoding and mask generation took \(maskDuration.formatted(.units(allowed: [.seconds, .milliseconds])))")

        // Write masks
        if let mask = mask {
            context.writePNG(maskImage, to: URL(filePath: mask))
        }

        // Overlay over original and save
        guard let outputImage = maskImage.withAlpha(0.6)?.composited(over: inputImage) else {
            print("Failed to blend mask.")
            throw ExitCode(EXIT_FAILURE)
        }
        context.writePNG(outputImage, to: URL(filePath: output))
    }
}

// Parses a command-line argument of the form "x,y" into a CGPoint.
extension CGPoint: ExpressibleByArgument {
    public init?(argument: String) {
        let components = argument.split(separator: ",").map(String.init)
        guard components.count == 2,
              let x = Double(components[0]),
              let y = Double(components[1])
        else { return nil }
        self.init(x: x, y: y)
    }
}
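// NOTE: `resized(to:)`, `withAlpha(_:)`, `render(_:pixelFormat:)`, and
// `writePNG(_:to:)` are not defined in this file; the package provides them in
// another source file. The extensions below are hypothetical reference sketches
// built on standard Core Image / Core Video APIs, showing what minimal
// implementations could look like. They are not the project's actual helpers.

extension CIImage {
    // Sketch: scale the image to an absolute pixel size with an affine transform.
    func resized(to size: CGSize) -> CIImage {
        let sx = size.width / extent.width
        let sy = size.height / extent.height
        return transformed(by: CGAffineTransform(scaleX: sx, y: sy))
    }

    // Sketch: multiply the alpha channel by `alpha` using CIColorMatrix,
    // leaving the color channel vectors at their identity defaults.
    func withAlpha(_ alpha: CGFloat) -> CIImage? {
        guard let filter = CIFilter(name: "CIColorMatrix") else { return nil }
        filter.setValue(self, forKey: kCIInputImageKey)
        filter.setValue(CIVector(x: 0, y: 0, z: 0, w: alpha), forKey: "inputAVector")
        return filter.outputImage
    }
}

extension CIContext {
    // Sketch: render a CIImage into a freshly allocated CVPixelBuffer of the
    // requested pixel format, sized to the image extent.
    func render(_ image: CIImage, pixelFormat: OSType) -> CVPixelBuffer? {
        var buffer: CVPixelBuffer?
        let status = CVPixelBufferCreate(kCFAllocatorDefault,
                                         Int(image.extent.width),
                                         Int(image.extent.height),
                                         pixelFormat,
                                         nil,
                                         &buffer)
        guard status == kCVReturnSuccess, let buffer else { return nil }
        render(image, to: buffer)
        return buffer
    }

    // Sketch: encode a CIImage as PNG and write it to disk, ignoring errors
    // to match the non-throwing call sites above.
    func writePNG(_ image: CIImage, to url: URL) {
        guard let colorSpace = CGColorSpace(name: CGColorSpace.sRGB) else { return }
        try? writePNGRepresentation(of: image, to: url, format: .RGBA8, colorSpace: colorSpace)
    }
}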