in training/dataset/vos_segment_loader.py
def load(self, frame_id, obj_ids=None):
    assert frame_id % self.ann_every == 0
    rle_mask = self.frame_annots[frame_id // self.ann_every]

    valid_objs_ids = set(range(len(rle_mask)))
    if self.valid_obj_ids is not None:
        # Remove the masklets that have been filtered out for this video
        valid_objs_ids &= set(self.valid_obj_ids)
    if obj_ids is not None:
        # Only keep the objects that have been sampled
        valid_objs_ids &= set(obj_ids)
    valid_objs_ids = sorted(list(valid_objs_ids))

    # Construct rle_mask_filtered that only contains the rle masks we are interested in
    id_2_idx = {}
    rle_mask_filtered = []
    for obj_id in valid_objs_ids:
        if rle_mask[obj_id] is not None:
            id_2_idx[obj_id] = len(rle_mask_filtered)
            rle_mask_filtered.append(rle_mask[obj_id])
        else:
            id_2_idx[obj_id] = None

    # Decode the masks
    raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute(
        2, 0, 1
    )  # (num_obj, h, w)
    segments = {}
    for obj_id in valid_objs_ids:
        if id_2_idx[obj_id] is None:
            segments[obj_id] = None
        else:
            idx = id_2_idx[obj_id]
            segments[obj_id] = raw_segments[idx]
    return segments
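For context, here is a minimal sketch of the data layout `load` expects and what its core decode step produces. It is illustrative only: the helper `make_rle`, the synthetic `frame_annots` contents, and `ann_every = 4` are assumptions, not part of the original file; only the COCO RLE encode/decode calls and the `permute(2, 0, 1)` reshaping mirror the code above.

```python
import numpy as np
import torch
from pycocotools import mask as mask_utils


def make_rle(h, w, y0, y1, x0, x1):
    """Encode a filled rectangle as a COCO RLE dict (hypothetical test helper)."""
    m = np.zeros((h, w), dtype=np.uint8)
    m[y0:y1, x0:x1] = 1
    return mask_utils.encode(np.asfortranarray(m))


# frame_annots[k] holds the per-object RLE masks for frame k * ann_every;
# a None entry means that object has no mask on that annotated frame.
frame_annots = [
    [make_rle(8, 8, 0, 4, 0, 4), None],                        # frame 0: obj 0 only
    [make_rle(8, 8, 2, 6, 2, 6), make_rle(8, 8, 4, 8, 4, 8)],  # frame 4: both objects
]
ann_every = 4

# Mirroring the body of load() for one annotated frame:
frame_id = 4
rle_mask = frame_annots[frame_id // ann_every]
present = [m for m in rle_mask if m is not None]
raw = torch.from_numpy(mask_utils.decode(present)).permute(2, 0, 1)
print(raw.shape)  # torch.Size([2, 8, 8]); each slice is one object's binary mask
```

The returned `segments` dict keys by object id rather than by tensor index, so callers can look up a sampled object directly and get either its mask tensor or `None` when the object is absent on that frame.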