coreml/export.py
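# Export script: converts the SAM2 image encoder, point prompt encoder, and mask
# decoder into separate Core ML packages with coremltools, validating each
# converted model against the original PyTorch predictor. It assumes
# ../notebooks/images/truck.jpg is reachable from the working directory, since
# that image is used to trace and validate the encoder.
#
# Illustrative invocation (a sketch, not the canonical command: flag names and
# defaults come from parse_args below, the variant value must match a
# SAM2Variant member name, and the point/label values here are only examples):
#
#   python coreml/export.py \
#       --variant Tiny \
#       --points '[[100, 200], [300, 400]]' \
#       --labels '[1, 0]' \
#       --output-dir ./exported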

import argparse
import os
import enum
from typing import List, Optional, Tuple
import ast

import torch
import numpy as np
from PIL import Image
from PIL.Image import Resampling
import coremltools as ct
from coremltools.converters.mil._deployment_compatibility import AvailableTarget
from coremltools import ComputeUnit
from coremltools.converters.mil.mil.passes.defs.quantization import ComputePrecision
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.mil import Builder as mb

from sam2.sam2_image_predictor import SAM2ImagePredictor


class SAM2Variant(enum.Enum):
    Tiny = "tiny"
    Small = "small"
    BasePlus = "base-plus"
    Large = "large"

    def fmt(self):
        if self == SAM2Variant.BasePlus:
            return "BasePlus"
        return self.value.capitalize()


SAM2_HW = (1024, 1024)


def parse_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument(
        "--output-dir",
        type=str,
        default=".",
        help="Provide location to save exported models.",
    )
    parser.add_argument(
        "--variant",
        type=lambda x: getattr(SAM2Variant, x),
        choices=[variant for variant in SAM2Variant],
        default=SAM2Variant.Small,
        help="SAM2 variant to export.",
    )
    parser.add_argument(
        "--points",
        type=str,
        help="List of 2D points, e.g., '[[10,20], [30,40]]'",
    )
    parser.add_argument(
        "--boxes",
        type=str,
        help="List of 2D bounding boxes, e.g., '[[10,20,30,40], [50,60,70,80]]'",
    )
    parser.add_argument(
        "--labels",
        type=str,
        help="List of binary labels for each points entry, denoting foreground (1) or background (0).",
    )
    parser.add_argument(
        "--min-deployment-target",
        type=lambda x: getattr(AvailableTarget, x),
        choices=[target for target in AvailableTarget],
        default=AvailableTarget.iOS17,
        help="Minimum deployment target for CoreML model.",
    )
    parser.add_argument(
        "--compute-units",
        type=lambda x: getattr(ComputeUnit, x),
        choices=[cu for cu in ComputeUnit],
        default=ComputeUnit.ALL,
        help="Which compute units to target for CoreML model.",
    )
    parser.add_argument(
        "--precision",
        type=lambda x: getattr(ComputePrecision, x),
        choices=[p for p in ComputePrecision],
        default=ComputePrecision.FLOAT16,
        help="Precision to use for quantization.",
    )
    return parser


# Core ML has no bicubic upsampling op, so aten::upsample_bicubic2d is lowered to
# bilinear upsampling here (see the tolerance note in validate_image_encoder).
@register_torch_op
def upsample_bicubic2d(context, node):
    x = context[node.inputs[0]]
    output_size = context[node.inputs[1]].val
    scale_factor_height = output_size[0] / x.shape[2]
    scale_factor_width = output_size[1] / x.shape[3]
    align_corners = context[node.inputs[2]].val
    x = mb.upsample_bilinear(
        x=x,
        scale_factor_height=scale_factor_height,
        scale_factor_width=scale_factor_width,
        align_corners=align_corners,
        name=node.name,
    )
    context.add(x)


class SAM2ImageEncoder(torch.nn.Module):
    def __init__(self, model: SAM2ImagePredictor):
        super().__init__()
        self.model = model

    @torch.no_grad()
    def forward(self, image):
        (img_embedding, feats_s0, feats_s1) = self.model.encode_image_raw(image)
        return img_embedding, feats_s0, feats_s1


def validate_image_encoder(
    model: ct.models.MLModel, ground_model: SAM2ImagePredictor, image: Image.Image
):
    prepared_image = image.resize(SAM2_HW, Resampling.BILINEAR)
    predictions = model.predict({"image": prepared_image})

    image = np.array(image.convert("RGB"))
    tch_image = ground_model._transforms(image)
    tch_image = tch_image[None, ...].to("cpu")
    ground_embedding, ground_feats_s0, ground_feats_s1 = ground_model.encode_image_raw(
        tch_image
    )
    ground_embedding, ground_feats_s0, ground_feats_s1 = (
        ground_embedding.numpy(),
        ground_feats_s0.numpy(),
        ground_feats_s1.numpy(),
    )

    img_max_diff = np.max(np.abs(predictions["image_embedding"] - ground_embedding))
    img_avg_diff = np.mean(np.abs(predictions["image_embedding"] - ground_embedding))
    s0_max_diff = np.max(np.abs(predictions["feats_s0"] - ground_feats_s0))
    s0_avg_diff = np.mean(np.abs(predictions["feats_s0"] - ground_feats_s0))
    s1_max_diff = np.max(np.abs(predictions["feats_s1"] - ground_feats_s1))
    s1_avg_diff = np.mean(np.abs(predictions["feats_s1"] - ground_feats_s1))

    print(
        f"Image Embedding: Max Diff: {img_max_diff:.4f}, Avg Diff: {img_avg_diff:.4f}"
    )
    print(f"Feats S0: Max Diff: {s0_max_diff:.4f}, Avg Diff: {s0_avg_diff:.4f}")
    print(f"Feats S1: Max Diff: {s1_max_diff:.4f}, Avg Diff: {s1_avg_diff:.4f}")

    # Lack of bicubic upsampling in CoreML causes slight differences
    # assert np.allclose(
    #     predictions["image_embedding"], ground_embedding, atol=2e1
    # )
    # assert np.allclose(predictions["feats_s0"], ground_feats_s0, atol=1e-1)
    # assert np.allclose(predictions["feats_s1"], ground_feats_s1, atol=1e-1)


def validate_prompt_encoder(
    model: ct.models.MLModel, ground_model: SAM2ImagePredictor, unnorm_coords, labels
):
    predictions = model.predict({"points": unnorm_coords, "labels": labels})

    (ground_sparse, ground_dense) = ground_model.encode_points_raw(
        unnorm_coords, labels
    )
    ground_sparse = ground_sparse.numpy()
    ground_dense = ground_dense.numpy()

    sparse_max_diff = np.max(np.abs(predictions["sparse_embeddings"] - ground_sparse))
    sparse_avg_diff = np.mean(np.abs(predictions["sparse_embeddings"] - ground_sparse))
    dense_max_diff = np.max(np.abs(predictions["dense_embeddings"] - ground_dense))
    dense_avg_diff = np.mean(np.abs(predictions["dense_embeddings"] - ground_dense))

    print(
        "Sparse Embeddings: Max Diff: {:.4f}, Avg Diff: {:.4f}".format(
            sparse_max_diff, sparse_avg_diff
        )
    )
    print(
        "Dense Embeddings: Max Diff: {:.4f}, Avg Diff: {:.4f}".format(
            dense_max_diff, dense_avg_diff
        )
    )

    assert np.allclose(predictions["sparse_embeddings"], ground_sparse, atol=9e-3)
    assert np.allclose(predictions["dense_embeddings"], ground_dense, atol=1e-3)


def validate_mask_decoder(
    model: ct.models.MLModel,
    ground_model: SAM2ImagePredictor,
    image_embedding,
    sparse_embedding,
    dense_embedding,
    feats_s0,
    feats_s1,
    precision: ComputePrecision,
):
    predictions = model.predict(
        {
            "image_embedding": image_embedding,
            "sparse_embedding": sparse_embedding,
            "dense_embedding": dense_embedding,
            "feats_s0": feats_s0,
            "feats_s1": feats_s1,
        }
    )

    ground_masks, scores = ground_model.decode_masks_raw(
        image_embedding, sparse_embedding, dense_embedding, [feats_s0, feats_s1]
    )
    ground_masks = ground_masks.numpy()

    masks_max_diff = np.max(np.abs(predictions["low_res_masks"] - ground_masks))
    masks_avg_diff = np.mean(np.abs(predictions["low_res_masks"] - ground_masks))

    print(
        "Masks: Max Diff: {:.4f}, Avg Diff: {:.4f}".format(
            masks_max_diff, masks_avg_diff
        )
    )

    # atol = 7e-2 if precision == ComputePrecision.FLOAT32 else 3e-1
    # assert np.allclose(predictions["low_res_masks"], ground_masks, atol=atol)

    print(f"Scores: {predictions['scores']}, ground: {scores}")
    assert np.allclose(predictions["scores"], scores, atol=1e-2)


class SAM2PointsEncoder(torch.nn.Module):
    def __init__(self, model: SAM2ImagePredictor):
        super().__init__()
        self.model = model

    @torch.no_grad()
    def forward(self, points, labels):
        prompt_embedding = self.model.encode_points_raw(points, labels)
        return prompt_embedding


class SAM2MaskDecoder(torch.nn.Module):
    def __init__(self, model: SAM2ImagePredictor):
        super().__init__()
        self.model = model

    @torch.no_grad()
    def forward(
        self, image_embedding, sparse_embedding, dense_embedding, feats_s0, feats_s1
    ):
        low_res_masks, iou_scores = self.model.decode_masks_raw(
            image_embedding, sparse_embedding, dense_embedding, [feats_s0, feats_s1]
        )
        return low_res_masks, iou_scores


def export_image_encoder(
    image_predictor: SAM2ImagePredictor,
    variant: SAM2Variant,
    output_dir: str,
    min_target: AvailableTarget,
    compute_units: ComputeUnit,
    precision: ComputePrecision,
) -> Tuple[int, int]:
    # Prepare input tensors
    image = Image.open("../notebooks/images/truck.jpg")
    image = np.array(image.convert("RGB"))
    orig_hw = (image.shape[0], image.shape[1])

    prepared_image = image_predictor._transforms(image)
    prepared_image = prepared_image[None, ...].to("cpu")

    traced_model = torch.jit.trace(
        SAM2ImageEncoder(image_predictor).eval(), prepared_image
    )

    # Fold ImageNet-style normalization into the Core ML image input:
    # ct.ImageType takes a single scale, so 0.226 (the mean of the per-channel
    # stds 0.229/0.224/0.225) approximates them, while the bias carries the
    # exact per-channel -mean/std terms.
    scale = 1 / (0.226 * 255.0)
    bias = [-0.485 / (0.229), -0.456 / (0.224), -0.406 / (0.225)]

    mlmodel = ct.convert(
        traced_model,
        inputs=[
            ct.ImageType(
                name="image",
                shape=(1, 3, SAM2_HW[0], SAM2_HW[1]),
                scale=scale,
                bias=bias,
            )
        ],
        outputs=[
            ct.TensorType(name="image_embedding"),
            ct.TensorType(name="feats_s0"),
            ct.TensorType(name="feats_s1"),
        ],
        minimum_deployment_target=min_target,
        compute_units=compute_units,
        compute_precision=precision,
    )

    image = Image.open("../notebooks/images/truck.jpg")
    validate_image_encoder(mlmodel, image_predictor, image)

    output_path = os.path.join(
        output_dir, f"SAM2_1{variant.fmt()}ImageEncoder{precision.value.upper()}"
    )
    mlmodel.save(output_path + ".mlpackage")
    return orig_hw


def export_points_prompt_encoder(
    image_predictor: SAM2ImagePredictor,
    variant: SAM2Variant,
    input_points: List[List[float]],
    input_labels: List[int],
    orig_hw: tuple,
    output_dir: str,
    min_target: AvailableTarget,
    compute_units: ComputeUnit,
    precision: ComputePrecision,
):
    image_predictor.model.sam_prompt_encoder.eval()

    points = torch.tensor(input_points, dtype=torch.float32)
    labels = torch.tensor(input_labels, dtype=torch.int32)

    unnorm_coords = image_predictor._transforms.transform_coords(
        points,
        normalize=True,
        orig_hw=orig_hw,
    )
    unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
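
    # The leading batch dimension added above keeps the traced example consistent
    # with the flexible Core ML input shapes declared below: points are (1, N, 2)
    # and labels are (1, N), where N ranges from 1 to 16 via ct.RangeDim.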
    traced_model = torch.jit.trace(
        SAM2PointsEncoder(image_predictor), (unnorm_coords, labels)
    )

    points_shape = ct.Shape(shape=(1, ct.RangeDim(lower_bound=1, upper_bound=16), 2))
    labels_shape = ct.Shape(shape=(1, ct.RangeDim(lower_bound=1, upper_bound=16)))

    mlmodel = ct.convert(
        traced_model,
        inputs=[
            ct.TensorType(name="points", shape=points_shape),
            ct.TensorType(name="labels", shape=labels_shape),
        ],
        outputs=[
            ct.TensorType(name="sparse_embeddings"),
            ct.TensorType(name="dense_embeddings"),
        ],
        minimum_deployment_target=min_target,
        compute_units=compute_units,
        compute_precision=precision,
    )

    validate_prompt_encoder(mlmodel, image_predictor, unnorm_coords, labels)

    output_path = os.path.join(
        output_dir, f"SAM2{variant.fmt()}PromptEncoder{precision.value.upper()}"
    )
    mlmodel.save(output_path + ".mlpackage")


def export_mask_decoder(
    image_predictor: SAM2ImagePredictor,
    variant: SAM2Variant,
    output_dir: str,
    min_target: AvailableTarget,
    compute_units: ComputeUnit,
    precision: ComputePrecision,
):
    image_predictor.model.sam_mask_decoder.eval()

    s0 = torch.randn(1, 32, 256, 256)
    s1 = torch.randn(1, 64, 128, 128)
    image_embedding = torch.randn(1, 256, 64, 64)
    sparse_embedding = torch.randn(1, 3, 256)
    dense_embedding = torch.randn(1, 256, 64, 64)

    traced_model = torch.jit.trace(
        SAM2MaskDecoder(image_predictor),
        (image_embedding, sparse_embedding, dense_embedding, s0, s1),
    )
    traced_model.eval()

    mlmodel = ct.convert(
        traced_model,
        inputs=[
            ct.TensorType(name="image_embedding", shape=[1, 256, 64, 64]),
            ct.TensorType(
                name="sparse_embedding",
                shape=ct.EnumeratedShapes(shapes=[[1, i, 256] for i in range(2, 16)]),
            ),
            ct.TensorType(name="dense_embedding", shape=[1, 256, 64, 64]),
            ct.TensorType(name="feats_s0", shape=[1, 32, 256, 256]),
            ct.TensorType(name="feats_s1", shape=[1, 64, 128, 128]),
        ],
        outputs=[
            ct.TensorType(name="low_res_masks"),
            ct.TensorType(name="scores"),
        ],
        minimum_deployment_target=min_target,
        compute_units=compute_units,
        compute_precision=precision,
    )

    validate_mask_decoder(
        mlmodel,
        image_predictor,
        image_embedding,
        sparse_embedding,
        dense_embedding,
        s0,
        s1,
        precision,
    )

    output_path = os.path.join(
        output_dir, f"SAM2{variant.fmt()}MaskDecoder{precision.value.upper()}"
    )
    mlmodel.save(output_path + ".mlpackage")


Point = Tuple[float, float]
Box = Tuple[float, float, float, float]


def export(
    output_dir: str,
    variant: SAM2Variant,
    points: Optional[List[Point]],
    boxes: Optional[List[Box]],
    labels: Optional[List[int]],
    min_target: AvailableTarget,
    compute_units: ComputeUnit,
    precision: ComputePrecision,
):
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device("cpu")

    # Build SAM2 model
    sam2_checkpoint = f"facebook/sam2.1-hiera-{variant.value}"

    with torch.no_grad():
        img_predictor = SAM2ImagePredictor.from_pretrained(
            sam2_checkpoint, device=device
        )
        img_predictor.model.eval()

        orig_hw = export_image_encoder(
            img_predictor, variant, output_dir, min_target, compute_units, precision
        )

        if boxes is not None and points is None:
            # Box-only prompts are not wired up yet; only point prompts are exported.
            raise ValueError("Boxes are not supported yet")
        else:
            export_points_prompt_encoder(
                img_predictor,
                variant,
                points,
                labels,
                orig_hw,
                output_dir,
                min_target,
                compute_units,
                precision,
            )

        export_mask_decoder(
            img_predictor, variant, output_dir, min_target, compute_units, precision
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SAM2 -> CoreML CLI")
    parser = parse_args(parser)
    args = parser.parse_args()

    points, boxes, labels = None, None, None
    if args.points:
        points = [tuple(p) for p in ast.literal_eval(args.points)]
    if args.boxes:
        boxes = [tuple(b) for b in ast.literal_eval(args.boxes)]
    if args.labels:
        labels = ast.literal_eval(args.labels)

    if boxes and points:
        raise ValueError("Cannot provide both points and boxes")

    if points:
        if not isinstance(points, list) or not all(
            isinstance(p, tuple) and len(p) == 2 for p in points
        ):
            raise ValueError("Points must be a list of 2D points")
    if labels:
        if not isinstance(labels, list) or not all(
            isinstance(l, int) and l in [0, 1] for l in labels
        ):
            raise ValueError("Labels must denote foreground (1) or background (0)")

    if points:
        if not labels or len(points) != len(labels):
            raise ValueError("Number of points must match the number of labels")
        if len(points) > 16:
            raise ValueError("Number of points must be less than or equal to 16")

    if boxes:
        if not isinstance(boxes, list) or not all(
            isinstance(b, tuple) and len(b) == 4 for b in boxes
        ):
            raise ValueError("Boxes must be a list of 4-element bounding boxes")

    export(
        args.output_dir,
        args.variant,
        points,
        boxes,
        labels,
        args.min_deployment_target,
        args.compute_units,
        args.precision,
    )
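
# A rough sketch of consuming the exported packages from Python with coremltools
# (the .mlpackage paths are placeholders for whatever the export functions above
# saved; `coords` and `labels` stand for (1, N, 2) float and (1, N) int arrays
# already mapped into the 1024x1024 input space, as in validate_prompt_encoder):
#
#   enc = ct.models.MLModel("path/to/ImageEncoder.mlpackage")
#   prm = ct.models.MLModel("path/to/PromptEncoder.mlpackage")
#   dec = ct.models.MLModel("path/to/MaskDecoder.mlpackage")
#
#   img = Image.open("truck.jpg").resize(SAM2_HW, Resampling.BILINEAR)
#   feats = enc.predict({"image": img})
#   emb = prm.predict({"points": coords, "labels": labels})
#   out = dec.predict({
#       "image_embedding": feats["image_embedding"],
#       "sparse_embedding": emb["sparse_embeddings"],
#       "dense_embedding": emb["dense_embeddings"],
#       "feats_s0": feats["feats_s0"],
#       "feats_s1": feats["feats_s1"],
#   })
#   low_res_masks, scores = out["low_res_masks"], out["scores"]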