in training/model/sam2.py [0:0]
def __init__(
self,
image_encoder,
memory_attention=None,
memory_encoder=None,
prob_to_use_pt_input_for_train=0.0,
prob_to_use_pt_input_for_eval=0.0,
prob_to_use_box_input_for_train=0.0,
prob_to_use_box_input_for_eval=0.0,
# if `num_frames_to_correct` is greater than 1, we do interactive point sampling on the 1st frame and on other randomly selected frames
num_frames_to_correct_for_train=1, # default: only iteratively sample on first frame
num_frames_to_correct_for_eval=1, # default: only iteratively sample on first frame
rand_frames_to_correct_for_train=False,
rand_frames_to_correct_for_eval=False,
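# a minimal sketch (hypothetical helper, not this class's implementation) of how the
# frames receiving correction clicks could be picked: the initial conditioning frames
# plus randomly selected extras, optionally randomizing the total count:
#   import random
#   def pick_frames_to_correct(num_frames, num_to_correct, init_cond_frames, rand=False):
#       if rand:  # randomly vary how many frames get corrected
#           num_to_correct = random.randint(len(init_cond_frames), num_to_correct)
#       extras = random.sample(
#           [t for t in range(num_frames) if t not in init_cond_frames],
#           num_to_correct - len(init_cond_frames),
#       )
#       return sorted(init_cond_frames + extras)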
# how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame)
# - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames
# - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames
# note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`;
# these are called initial conditioning frames because, as we track the video, more conditioning frames might be added
# when a frame receives correction clicks under point input (if `add_all_frames_to_correct_as_cond=True`)
num_init_cond_frames_for_train=1, # default: only use the first frame as initial conditioning frame
num_init_cond_frames_for_eval=1, # default: only use the first frame as initial conditioning frame
rand_init_cond_frames_for_train=True, # default: random 1~num_init_cond_frames_for_train cond frames (to be consistent w/ the previous TA data loader)
rand_init_cond_frames_for_eval=False,
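# a minimal sketch (hypothetical helper) of the initial-conditioning-frame choice
# described above; frame 0 is always included and the rest are drawn from later frames:
#   import random
#   def pick_init_cond_frames(num_frames, num_init_cond_frames, rand_init_cond_frames):
#       n = (random.randint(1, num_init_cond_frames)
#            if rand_init_cond_frames else num_init_cond_frames)
#       return [0] + sorted(random.sample(range(1, num_frames), n - 1))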
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
# if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to the initial conditioning frames only
add_all_frames_to_correct_as_cond=False,
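# e.g. (sketch, with a hypothetical `cond_frame_inds` set) after frame t receives
# correction clicks during tracking:
#   if add_all_frames_to_correct_as_cond:
#       cond_frame_inds.add(t)  # frame t now also serves as a conditioning frame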
# how many additional correction points to sample (on each frame selected to be corrected)
# note that the first frame receives an initial input click (in addition to any correction clicks)
num_correction_pt_per_frame=7,
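# a minimal sketch of the per-frame correction loop, assuming hypothetical `predict`,
# `frame`, `gt_mask`, and an `initial_click` on this frame:
#   import numpy as np
#   points = [initial_click]
#   for _ in range(num_correction_pt_per_frame):
#       pred_mask = predict(frame, points) > 0
#       error = np.logical_xor(pred_mask, gt_mask)  # union of false-pos and false-neg pixels
#       points.append(sample_pt(error))             # e.g. the sampler sketched below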
# method for point sampling during evaluation
# "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary)
# default to "center" to be consistent with evaluation in the SAM paper
pt_sampling_for_eval="center",
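# a minimal sketch (hypothetical helper) of both sampling methods, assuming a boolean
# error-region mask; "center" uses a Euclidean distance transform to find the pixel
# farthest from the region boundary:
#   import numpy as np
#   from scipy.ndimage import distance_transform_edt
#   def sample_pt(error_region, method="center"):
#       if method == "uniform":  # any error pixel, uniformly at random
#           ys, xs = np.nonzero(error_region)
#           i = np.random.randint(len(ys))
#           return xs[i], ys[i]
#       padded = np.pad(error_region, 1)  # pad so image borders count as boundary
#       dist = distance_transform_edt(padded)[1:-1, 1:-1]
#       y, x = np.unravel_index(dist.argmax(), dist.shape)
#       return x, y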
# During training, we optionally allow sampling the correction points from GT regions
# instead of the prediction error regions with a small probability. This might allow the
# model to overfit less to the error regions in training datasets
prob_to_sample_from_gt_for_train=0.0,
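# e.g. (sketch) the region to sample the next click from could be switched per point:
#   import random
#   from_gt = self.training and random.random() < prob_to_sample_from_gt_for_train
#   region = gt_mask if from_gt else error_region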
use_act_ckpt_iterative_pt_sampling=False, # whether to use activation checkpointing in the iterative point sampling loop to save GPU memory
# whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
# of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
forward_backbone_per_frame_for_eval=False,
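# a minimal sketch (hypothetical `video` tensor of shape (T, C, H, W)) of the two modes:
#   if forward_backbone_per_frame_for_eval:
#       feats = [image_encoder(frame[None]) for frame in video]  # one frame per forward
#   else:
#       feats = image_encoder(video)  # all frames in a single forward pass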
freeze_image_encoder=False,
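# freezing is typically implemented by disabling gradients on the encoder's parameters:
#   if freeze_image_encoder:
#       for p in image_encoder.parameters():
#           p.requires_grad = False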
**kwargs,