def _filter_db()

in dataset/co3d_dataset.py [0:0]


    def _filter_db(self):
        """Apply the configured frame/sequence filters in place.

        Mutates ``self.frame_annots`` (a list of dicts holding a
        ``"frame_annotation"`` entry, plus a ``"subset"`` entry when subset
        filtering is enabled) and ``self.seq_annots`` (a dict keyed by
        sequence name). Filters are applied in this order:

        1. ``remove_empty_masks``: drop frames whose mask is ``None`` or whose
           ``mask.mass`` is not greater than 1.
        2. ``subsets``: keep only frames whose ``"subset"`` is in
           ``self.subsets``.
        3. ``limit_category_to``: keep only sequences whose ``category`` is
           in the list.
        4. ``pick_sequence`` / ``exclude_sequence``: keep only / drop the
           listed sequence names.
        5. ``limit_sequences_to``: keep the first N sequences in dict order.
        6. Drop every frame whose sequence is no longer in ``self.seq_annots``.
        7. ``n_frames_per_sequence``: keep at most N frames per sequence,
           picked by a shuffle seeded from the sequence name plus
           ``self.seed`` (reproducible, but different per sequence).
        8. ``limit_to``: keep only the first N frames.

        Raises:
            ValueError: if ``subsets`` is set but ``subset_lists_file`` is
                not, or if the subset filter leaves no frames.
        """
        if self.remove_empty_masks:
            print("Removing images with empty masks.")
            old_len = len(self.frame_annots)
            self.frame_annots = [
                frame
                for frame in self.frame_annots
                if frame["frame_annotation"].mask is not None
                and frame["frame_annotation"].mask.mass > 1
            ]
            print("... filtered %d -> %d" % (old_len, len(self.frame_annots)))

        # this has to be called after joining with categories!!
        if self.subsets:
            if not self.subset_lists_file:
                raise ValueError(
                    "Subset filter is on but subset_lists_file was not given"
                )

            print(f"Limitting Co3D dataset to the '{self.subsets}' subsets.")
            # keep only the frames that belong to one of the requested subsets
            self.frame_annots = [
                entry for entry in self.frame_annots if entry["subset"] in self.subsets
            ]
            if not self.frame_annots:
                raise ValueError(
                    f"There are no frames in the '{self.subsets}' subsets!"
                )

            # NOTE(review): presumably rebuilds the frame indexes and, with
            # filter_seq_annots=True, also drops sequences left without
            # frames — confirm against _invalidate_indexes.
            self._invalidate_indexes(filter_seq_annots=True)

        if len(self.limit_category_to) > 0:
            print(f"Limitting dataset to categories: {self.limit_category_to}")
            # Iterate (name, entry) pairs; iterating .values() cannot be
            # unpacked into a name and an annotation.
            self.seq_annots = {
                name: entry
                for name, entry in self.seq_annots.items()
                if entry.category in self.limit_category_to
            }

        # sequence filters: "pick" keeps only the listed names,
        # "exclude" drops the listed names
        for prefix in ("pick", "exclude"):
            attr = f"{prefix}_sequence"
            arr = getattr(self, attr)
            if len(arr) == 0:
                continue
            orig_len = len(self.seq_annots)
            print(f"{attr}: {str(arr)}")
            exclude = prefix == "exclude"
            self.seq_annots = {
                name: entry
                for name, entry in self.seq_annots.items()
                # membership must match when picking, and differ when excluding
                if (name in arr) != exclude
            }
            print("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))

        if self.limit_sequences_to > 0:
            # keeps the first N sequences in the dict's insertion order
            self.seq_annots = dict(
                islice(self.seq_annots.items(), self.limit_sequences_to)
            )

        # retain only frames from retained sequences
        self.frame_annots = [
            f
            for f in self.frame_annots
            if f["frame_annotation"].sequence_name in self.seq_annots
        ]

        # NOTE(review): presumably rebuilds self.seq_to_idx, which the
        # per-sequence decimation below reads — confirm.
        self._invalidate_indexes()

        if self.n_frames_per_sequence > 0:
            print(f"Taking max {self.n_frames_per_sequence} per sequence.")
            keep_idx = []
            for seq, seq_indices in self.seq_to_idx.items():
                # infer the seed from the sequence name, this is reproducible
                # and makes the selection differ for different sequences
                seed = _seq_name_to_seed(seq) + self.seed
                # sort first so the shuffle does not depend on index order
                seq_idx_shuffled = random.Random(seed).sample(
                    sorted(seq_indices), len(seq_indices)
                )
                keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])

            print("... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)))
            self.frame_annots = [self.frame_annots[i] for i in keep_idx]
            self._invalidate_indexes(filter_seq_annots=False)
            # sequences are not decimated, so self.seq_annots is valid

        if 0 < self.limit_to < len(self.frame_annots):
            print(
                "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
            )
            self.frame_annots = self.frame_annots[: self.limit_to]
            # truncation may remove whole sequences, so seq_annots is re-filtered
            self._invalidate_indexes(filter_seq_annots=True)