in sam2/modeling/sam2_base.py [0:0]
def _build_sam_heads(self):
    """Create the SAM prompt encoder, mask decoder and object-pointer projections.

    Attributes assigned on ``self``, in this order:

    * ``sam_prompt_embed_dim`` -- copy of ``self.hidden_dim``.
    * ``sam_image_embedding_size`` -- ``self.image_size // self.backbone_stride``.
    * ``sam_prompt_encoder`` -- a ``PromptEncoder``.
    * ``sam_mask_decoder`` -- a ``MaskDecoder`` wrapping a ``TwoWayTransformer``.
    * ``obj_ptr_proj`` -- ``Linear``, ``MLP`` or ``Identity`` depending on
      ``use_obj_ptrs_in_encoder`` / ``use_mlp_for_obj_ptr_proj``.
    * ``obj_ptr_tpos_proj`` -- ``Linear`` or ``Identity`` depending on
      ``proj_tpos_enc_in_obj_ptrs``.
    """
    hidden_dim = self.hidden_dim
    self.sam_prompt_embed_dim = hidden_dim
    grid_side = self.image_size // self.backbone_stride
    self.sam_image_embedding_size = grid_side

    # PromptEncoder and MaskDecoder come from SAM; the literal hyperparameters
    # below (e.g. `mask_in_chans=16`) are taken from the SAM code.
    self.sam_prompt_encoder = PromptEncoder(
        embed_dim=hidden_dim,
        image_embedding_size=(grid_side, grid_side),
        input_image_size=(self.image_size, self.image_size),
        mask_in_chans=16,
    )

    two_way_transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=hidden_dim,
        mlp_dim=2048,
        num_heads=8,
    )
    self.sam_mask_decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=two_way_transformer,
        transformer_dim=hidden_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        use_high_res_features=self.use_high_res_features_in_sam,
        iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
        pred_obj_scores=self.pred_obj_scores,
        pred_obj_scores_mlp=self.pred_obj_scores_mlp,
        use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
        # Kept as keyword expansion so a key that duplicates one of the
        # explicit arguments above still raises instead of overriding it.
        **(self.sam_mask_decoder_extra_args or {}),
    )

    if not self.use_obj_ptrs_in_encoder:
        self.obj_ptr_proj = torch.nn.Identity()
    else:
        # Projection applied to SAM output tokens to turn them into object
        # pointers. The Linear layer is always constructed first and then
        # replaced when the MLP variant is requested.
        # NOTE(review): the throwaway Linear presumably still consumes RNG
        # state during init, so this order is kept as-is -- confirm before
        # simplifying.
        self.obj_ptr_proj = torch.nn.Linear(hidden_dim, hidden_dim)
        if self.use_mlp_for_obj_ptr_proj:
            self.obj_ptr_proj = MLP(hidden_dim, hidden_dim, hidden_dim, 3)

    # Optional linear projection of the temporal positional encoding used in
    # object pointers, to avoid potential interference with the spatial
    # positional encoding.
    self.obj_ptr_tpos_proj = (
        torch.nn.Linear(hidden_dim, self.mem_dim)
        if self.proj_tpos_enc_in_obj_ptrs
        else torch.nn.Identity()
    )