func getMask()

in SAM2-Demo/Common/SAM2.swift [152:186]


    /// Decodes a mask from the cached image and prompt encodings and returns it as a tinted image.
    ///
    /// The mask chosen by `bestMask(for:)` is converted to an image using its own value range,
    /// resized to `original_size`, and thresholded at the normalised position of the value 0
    /// within that range.
    ///
    /// - Parameter original_size: Target size passed to `resizeImage` (presumably the source
    ///   image size — confirm with callers).
    /// - Returns: The tinted mask. Returns `nil` in any of these cases:
    ///   - the image or prompt encodings are missing;
    ///   - the decoded mask has no usable value range (it is empty, or every value is identical);
    ///   - image conversion or resizing produces no image.
    /// - Throws: `SAM2Error.modelNotLoaded` when `maskDecoderModel` is nil. Also propagates errors
    ///   from the model prediction and from `resizeImage`.
    func getMask(for original_size: CGSize) async throws -> CIImage? {
        guard let model = maskDecoderModel else {
            throw SAM2Error.modelNotLoaded
        }

        // Both the image and the prompt must already be encoded; otherwise there is nothing to decode.
        guard let image_embedding = self.imageEncodings?.image_embedding,
              let feats0 = self.imageEncodings?.feats_s0,
              let feats1 = self.imageEncodings?.feats_s1,
              let sparse_embedding = self.promptEncodings?.sparse_embeddings,
              let dense_embedding = self.promptEncodings?.dense_embeddings else {
            return nil
        }

        let output = try model.prediction(image_embedding: image_embedding, sparse_embedding: sparse_embedding, dense_embedding: dense_embedding, feats_s0: feats0, feats_s1: feats1)

        // Extract best mask and ignore the others
        let lowFeatureMask = bestMask(for: output)

        // TODO: optimization
        // Preserve range for upsampling. Start from ±infinity instead of magic sentinels, so the
        // range is correct for any finite values. NaN entries never win a comparison and are skipped.
        var minValue = Double.infinity
        var maxValue = -Double.infinity
        for i in 0..<lowFeatureMask.count {
            let v = lowFeatureMask[i].doubleValue
            if v > maxValue { maxValue = v }
            if v < minValue { minValue = v }
        }

        // An empty mask leaves the range inverted (+inf / -inf). A constant mask gives a zero-width range.
        // Normalising by either would divide by zero, so there is no meaningful mask to produce.
        guard maxValue > minValue else {
            return nil
        }

        // Where the value 0 falls after linear normalisation of [minValue, maxValue] to [0, 1].
        // NOTE(review): assumes cgImage(min:max:) maps minValue -> 0 and maxValue -> 1 linearly — confirm.
        let threshold = -minValue / (maxValue - minValue)

        // Resize first, then threshold
        guard let maskcgImage = lowFeatureMask.cgImage(min: minValue, max: maxValue) else {
            return nil
        }
        // A null color space disables Core Image color management, so pixel values are not converted.
        let ciImage = CIImage(cgImage: maskcgImage, options: [.colorSpace: NSNull()])
        let resizedImage = try resizeImage(ciImage, to: original_size, applyingThreshold: Float(threshold))
        return resizedImage?.maskedToAlpha()?.samTinted()
    }