in training/model/sam2.py [0:0]
def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
    """
    Prepare input mask, point or box prompts. Optionally, we allow tracking from
    a custom `start_frame_idx` to the end of the video (for evaluation purposes).

    Using ``self.rng``, this randomly decides whether the initial conditioning
    frames are prompted with ground-truth masks or with points/boxes, which
    frames serve as initial conditioning frames, and on which frames correction
    clicks should later be added. Sampling settings are read from the
    ``*_for_train`` attributes when ``self.training`` is true and from the
    ``*_for_eval`` attributes otherwise. A single-frame input always uses
    point/box prompts on exactly one frame.

    Args:
        backbone_out: dict that is updated in place with the prompt information.
        input: batch object exposing ``masks`` (one ground-truth mask per frame,
            indexed by frame position; each must support ``.unsqueeze(1)``) and
            ``num_frames``.
        start_frame_idx: first frame to track from. It is always an initial
            conditioning frame, and frames before it are never sampled.

    Returns:
        The same ``backbone_out`` dict, with these keys set:
        ``gt_masks_per_frame``, ``num_frames``, ``use_pt_input``,
        ``init_cond_frames``, ``frames_not_in_init_cond``,
        ``mask_inputs_per_frame``, ``point_inputs_per_frame`` and
        ``frames_to_add_correction_pt``.

    Raises:
        AssertionError: if the configured number of initial conditioning frames
            is < 1, or if point inputs are used and fewer frames to correct than
            initial conditioning frames are configured.
    """
    # Load the ground-truth masks on all frames (so that we can later
    # sample correction points from them), keyed by frame index.
    gt_masks_per_frame = {
        stage_id: masks.unsqueeze(1)  # [B, 1, H_im, W_im]
        for stage_id, masks in enumerate(input.masks)
    }
    backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
    num_frames = input.num_frames
    backbone_out["num_frames"] = num_frames

    # Select the prompt-sampling settings for the current mode.
    if self.training:
        prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
        prob_to_use_box_input = self.prob_to_use_box_input_for_train
        num_frames_to_correct = self.num_frames_to_correct_for_train
        rand_frames_to_correct = self.rand_frames_to_correct_for_train
        num_init_cond_frames = self.num_init_cond_frames_for_train
        rand_init_cond_frames = self.rand_init_cond_frames_for_train
    else:
        prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
        prob_to_use_box_input = self.prob_to_use_box_input_for_eval
        num_frames_to_correct = self.num_frames_to_correct_for_eval
        rand_frames_to_correct = self.rand_frames_to_correct_for_eval
        num_init_cond_frames = self.num_init_cond_frames_for_eval
        rand_init_cond_frames = self.rand_init_cond_frames_for_eval
    if num_frames == 1:
        # here we handle a special case for mixing video + SAM on image training,
        # where we force using point input for the SAM task on static images
        prob_to_use_pt_input = 1.0
        num_frames_to_correct = 1
        num_init_cond_frames = 1
    assert (
        num_init_cond_frames >= 1
    ), f"num_init_cond_frames must be >= 1, got {num_init_cond_frames}"

    # NOTE: changing the order or conditions of the `self.rng` draws below
    # changes which prompts a given RNG state produces.
    # Randomly decide whether to use point inputs or mask inputs
    # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0)
    use_pt_input = self.rng.random() < prob_to_use_pt_input
    if rand_init_cond_frames and num_init_cond_frames > 1:
        # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames
        num_init_cond_frames = self.rng.integers(
            1, num_init_cond_frames, endpoint=True
        )
    if (
        use_pt_input
        and rand_frames_to_correct
        and num_frames_to_correct > num_init_cond_frames
    ):
        # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample
        # correction clicks (only for the case of point input)
        num_frames_to_correct = self.rng.integers(
            num_init_cond_frames, num_frames_to_correct, endpoint=True
        )
    backbone_out["use_pt_input"] = use_pt_input

    # Sample initial conditioning frames
    if num_init_cond_frames == 1:
        init_cond_frames = [start_frame_idx]  # starting frame
    else:
        # starting frame + randomly selected remaining frames (without replacement)
        init_cond_frames = [start_frame_idx] + self.rng.choice(
            range(start_frame_idx + 1, num_frames),
            num_init_cond_frames - 1,
            replace=False,
        ).tolist()
    backbone_out["init_cond_frames"] = init_cond_frames
    backbone_out["frames_not_in_init_cond"] = [
        t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
    ]

    # Prepare mask or point inputs on initial conditioning frames
    backbone_out["mask_inputs_per_frame"] = {}  # {frame_idx: <input_masks>}
    backbone_out["point_inputs_per_frame"] = {}  # {frame_idx: <input_points>}
    for t in init_cond_frames:
        if not use_pt_input:
            backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
        else:
            # A box is drawn per frame only on the point-input path, so
            # P(box) = prob_to_use_pt_input * prob_to_use_box_input
            use_box_input = self.rng.random() < prob_to_use_box_input
            if use_box_input:
                points, labels = sample_box_points(
                    gt_masks_per_frame[t],
                )
            else:
                # (here we only sample **one initial point** on initial conditioning frames from the
                # ground-truth mask; we may sample more correction points on the fly)
                points, labels = get_next_point(
                    gt_masks=gt_masks_per_frame[t],
                    pred_masks=None,
                    method=(
                        "uniform" if self.training else self.pt_sampling_for_eval
                    ),
                )

            point_inputs = {"point_coords": points, "point_labels": labels}
            backbone_out["point_inputs_per_frame"][t] = point_inputs

    # Sample frames where we will add correction clicks on the fly
    # based on the error between prediction and ground-truth masks
    if not use_pt_input:
        # no correction points will be sampled when using mask inputs
        frames_to_add_correction_pt = []
    elif num_frames_to_correct == num_init_cond_frames:
        # Copy so that this entry and backbone_out["init_cond_frames"] are
        # independent lists (mutating one must not silently change the other).
        frames_to_add_correction_pt = list(init_cond_frames)
    else:
        assert num_frames_to_correct > num_init_cond_frames, (
            f"num_frames_to_correct ({num_frames_to_correct}) must be >= "
            f"num_init_cond_frames ({num_init_cond_frames}) when using point inputs"
        )
        # initial cond frame + randomly selected remaining frames (without replacement)
        extra_num = num_frames_to_correct - num_init_cond_frames
        frames_to_add_correction_pt = (
            init_cond_frames
            + self.rng.choice(
                backbone_out["frames_not_in_init_cond"], extra_num, replace=False
            ).tolist()
        )
    backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt

    return backbone_out