in dataset/co3d_dataset.py [0:0]
def _filter_db(self):
    """Apply the configured filters to `self.frame_annots` / `self.seq_annots`.

    The filters run in a fixed order; each optional one runs only when its
    option is set:

    1. `remove_empty_masks`: drop frames whose mask is `None` or whose
       `mask.mass` is not greater than 1.
    2. `subsets`: keep only frames whose "subset" entry is in `subsets`.
    3. `limit_category_to`: keep only sequences whose `category` is listed.
    4. `pick_sequence` / `exclude_sequence`: keep only / drop the named
       sequences.
    5. `limit_sequences_to`: keep only the first N remaining sequences.
    6. Always: drop frames whose sequence is no longer in `seq_annots`.
    7. `n_frames_per_sequence`: keep at most N frames per sequence, chosen
       by a shuffle seeded from the sequence name and `self.seed`.
    8. `limit_to`: truncate the frame list to its first N entries.

    `self._invalidate_indexes` is called after the steps that change the
    frame list (presumably to rebuild derived indexes such as
    `self.seq_to_idx` -- confirm against its definition).

    Raises:
        ValueError: if `subsets` is set but `subset_lists_file` is not, or
            if no frame belongs to the requested subsets.
    """
    if self.remove_empty_masks:
        print("Removing images with empty masks.")
        old_len = len(self.frame_annots)
        self.frame_annots = [
            frame
            for frame in self.frame_annots
            if frame["frame_annotation"].mask is not None
            and frame["frame_annotation"].mask.mass > 1
        ]
        print("... filtered %d -> %d" % (old_len, len(self.frame_annots)))

    # this has to be called after joining with categories!!
    if self.subsets:
        if not self.subset_lists_file:
            raise ValueError(
                "Subset filter is on but subset_lists_file was not given"
            )

        print(f"Limitting Co3D dataset to the '{self.subsets}' subsets.")

        # keep only the frames that belong to one of the requested subsets
        self.frame_annots = [
            entry for entry in self.frame_annots if entry["subset"] in self.subsets
        ]
        if not self.frame_annots:
            raise ValueError(
                f"There are no frames in the '{self.subsets}' subsets!"
            )

        self._invalidate_indexes(filter_seq_annots=True)

    if len(self.limit_category_to) > 0:
        print(f"Limitting dataset to categories: {self.limit_category_to}")
        # Iterate over (name, entry) pairs: iterating the values alone
        # cannot be unpacked into a name and an annotation.
        self.seq_annots = {
            name: entry
            for name, entry in self.seq_annots.items()
            if entry.category in self.limit_category_to
        }

    # sequence filters
    for prefix in ("pick", "exclude"):
        attr = f"{prefix}_sequence"
        arr = getattr(self, attr)
        if len(arr) > 0:
            orig_len = len(self.seq_annots)
            print(f"{attr}: {str(arr)}")
            # "pick" keeps the listed names, "exclude" keeps all the others
            exclude = prefix == "exclude"
            self.seq_annots = {
                name: entry
                for name, entry in self.seq_annots.items()
                if (name in arr) != exclude
            }
            print("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))

    if self.limit_sequences_to > 0:
        # keep the first `limit_sequences_to` sequences in dict order
        self.seq_annots = dict(
            islice(self.seq_annots.items(), self.limit_sequences_to)
        )

    # retain only frames from retained sequences
    self.frame_annots = [
        f
        for f in self.frame_annots
        if f["frame_annotation"].sequence_name in self.seq_annots
    ]

    self._invalidate_indexes()

    if self.n_frames_per_sequence > 0:
        print(f"Taking max {self.n_frames_per_sequence} per sequence.")
        keep_idx = []
        for seq, seq_indices in self.seq_to_idx.items():
            # infer the seed from the sequence name, this is reproducible
            # and makes the selection differ for different sequences
            seed = _seq_name_to_seed(seq) + self.seed
            # sorting first makes the shuffle independent of the order in
            # which the indices happen to be stored
            seq_idx_shuffled = random.Random(seed).sample(
                sorted(seq_indices), len(seq_indices)
            )
            keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])

        print("... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)))
        self.frame_annots = [self.frame_annots[i] for i in keep_idx]
        self._invalidate_indexes(filter_seq_annots=False)
        # sequences are not decimated, so self.seq_annots is valid

    if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
        print(
            "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
        )
        self.frame_annots = self.frame_annots[: self.limit_to]
        self._invalidate_indexes(filter_seq_annots=True)