def filter()

in src/datatrove/pipeline/filters/fasttext_filter.py [0:0]


    def filter(self, doc: Document) -> bool:
        """Classify every unit of ``doc`` and keep only the units that pass the label rules.

        The text is split with ``split_into_parts`` using ``self.filter_mode``. Each unit
        is stripped, has its newlines replaced by ``self.newline_replacement`` and is
        scored with ``self.model.predict(..., k=-1)``.

        Decision rule per unit:
            * if ``self.keep_labels`` is non-empty, the unit is kept when at least one
              listed label scores at or above its threshold (``self.remove_labels`` is
              not consulted in that case);
            * otherwise the unit is kept unless a label from ``self.remove_labels``
              scores at or above its threshold; with no removal rules every unit is kept.

        Side effects:
            * ``doc.text`` is overwritten with the concatenation of the kept units
              (the original, unstripped unit text);
            * one ``"kept_span"`` / ``"removed_span"`` stat is recorded per unit;
            * when ``self.save_labels_in_metadata`` is set, the mean score of every
              predicted label over all units (kept or removed) is stored in
              ``doc.metadata`` under the label as returned by the model.

        Args:
            doc: document to filter; its ``text`` and ``metadata`` are modified in place.

        Returns:
            ``True`` if the remaining text contains any non-whitespace character.
        """

        def reaches_threshold(scores_by_label, rules):
            # A label missing from the prediction falls back to a hugely negative
            # score, so it cannot satisfy any realistic threshold.
            return any(
                scores_by_label.get(f"__label__{label}", -9e9) >= min_score for label, min_score in rules
            )

        def unit_passes(scores_by_label):
            if self.keep_labels:
                return reaches_threshold(scores_by_label, self.keep_labels)
            if not self.remove_labels:
                return True
            return not reaches_threshold(scores_by_label, self.remove_labels)

        parts = split_into_parts(doc.text, mode=self.filter_mode)
        surviving_parts = []
        scores_per_label = defaultdict(list)
        for part in parts:
            model_input = part.strip().replace("\n", self.newline_replacement)
            predicted_labels, predicted_scores = self.model.predict(model_input, k=-1)

            if self.save_labels_in_metadata:
                # Collected for every unit, independent of whether it is kept.
                for predicted_label, predicted_score in zip(predicted_labels, predicted_scores):
                    scores_per_label[predicted_label].append(predicted_score)

            keep_part = unit_passes(dict(zip(predicted_labels, predicted_scores)))
            if keep_part:
                surviving_parts.append(part)
            self.stat_update("kept_span" if keep_part else "removed_span")

        doc.text = "".join(surviving_parts)

        if self.save_labels_in_metadata:
            mean_scores = {
                predicted_label: np.mean(values).item() for predicted_label, values in scores_per_label.items()
            }
            doc.metadata.update(mean_scores)

        return bool(doc.text.strip())