in sam2/sam2_video_predictor.py [0:0]
def propagate_in_video_preflight(self, inference_state):
    """Prepare inference_state and consolidate temporary outputs before tracking.

    Steps, all applied in place to ``inference_state``:

    1. Mark tracking as started (no new objects may be added until reset).
    2. For non-conditioning and then conditioning outputs, merge every frame that
       has a per-object temporary output (from `add_new_points_or_box` /
       `add_new_mask`) into ``output_dict`` and the per-object slices.
    3. Empty the temporary per-object storage.
    4. Drop any non-conditioning output on a frame that also has a conditioning
       output.
    5. Check that the consolidated frames are exactly the frames that received
       point or mask inputs.

    Args:
        inference_state: session-state dict mutated in place. Keys used here:
            "tracking_has_started", "temp_output_dict_per_obj", "output_dict",
            "output_dict_per_obj", "consolidated_frame_inds",
            "point_inputs_per_obj" and "mask_inputs_per_obj".

    Raises:
        AssertionError: (only when Python assertions are enabled) if a frame
            recorded as a consolidated conditioning frame has no entry in
            ``output_dict["cond_frame_outputs"]``, or if the consolidated frames
            differ from the frames with point/mask inputs (incorrect workflow).
    """
    # Tracking has started and we don't allow adding new objects until session is reset.
    inference_state["tracking_has_started"] = True
    batch_size = self._get_obj_num(inference_state)
    # Whether to clear non-conditioning memory around each consolidated frame.
    # This depends only on the predictor config and the object count, so compute
    # it once rather than once per consolidated frame.
    clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
        self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
    )

    # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
    # add them into "output_dict".
    temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
    output_dict = inference_state["output_dict"]
    # "consolidated_frame_inds" contains indices of those frames where consolidated
    # temporary outputs have been added (either in this call or any previous calls
    # to `propagate_in_video_preflight`).
    consolidated_frame_inds = inference_state["consolidated_frame_inds"]
    # Non-conditioning outputs are consolidated first, then conditioning ones.
    for is_cond in [False, True]:
        # Separately consolidate conditioning and non-conditioning temp outputs
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
        # Find all the frames that contain temporary outputs for any objects
        # (these should be the frames that have just received clicks for mask inputs
        # via `add_new_points_or_box` or `add_new_mask`)
        temp_frame_inds = set()
        for obj_temp_output_dict in temp_output_dict_per_obj.values():
            temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
        consolidated_frame_inds[storage_key].update(temp_frame_inds)
        # consolidate the temporary output across all objects on this frame
        for frame_idx in temp_frame_inds:
            consolidated_out = self._consolidate_temp_output_across_obj(
                inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
            )
            # merge them into "output_dict" and also create per-object slices
            output_dict[storage_key][frame_idx] = consolidated_out
            self._add_output_per_object(
                inference_state, frame_idx, consolidated_out, storage_key
            )
            if clear_non_cond_mem:
                # clear non-conditioning memory of the surrounding frames
                # NOTE(review): this also runs in the non-conditioning pass. If the
                # helper's window includes `frame_idx` itself (or nearby frames
                # consolidated in the earlier pass), those outputs would be removed
                # from output_dict while staying listed in consolidated_frame_inds.
                # Verify against `_clear_non_cond_mem_around_input`.
                self._clear_non_cond_mem_around_input(inference_state, frame_idx)
        # clear temporary outputs in `temp_output_dict_per_obj`
        for obj_temp_output_dict in temp_output_dict_per_obj.values():
            obj_temp_output_dict[storage_key].clear()

    # edge case: if an output is added to "cond_frame_outputs", we remove any prior
    # output on the same frame in "non_cond_frame_outputs"
    for frame_idx in output_dict["cond_frame_outputs"]:
        output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
    for obj_output_dict in inference_state["output_dict_per_obj"].values():
        for frame_idx in obj_output_dict["cond_frame_outputs"]:
            obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
    for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
        assert frame_idx in output_dict["cond_frame_outputs"], (
            f"frame {frame_idx} is recorded as a consolidated conditioning frame "
            "but has no entry in output_dict['cond_frame_outputs']"
        )
        consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)

    # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
    # with either points or mask inputs (which should be true under a correct workflow).
    all_consolidated_frame_inds = (
        consolidated_frame_inds["cond_frame_outputs"]
        | consolidated_frame_inds["non_cond_frame_outputs"]
    )
    input_frames_inds = set()
    for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
        input_frames_inds.update(point_inputs_per_frame.keys())
    for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
        input_frames_inds.update(mask_inputs_per_frame.keys())
    assert all_consolidated_frame_inds == input_frames_inds, (
        "consolidated frames do not match frames with point/mask inputs: "
        f"consolidated={all_consolidated_frame_inds}, inputs={input_frames_inds}"
    )