in src/datatrove/pipeline/filters/fasttext_filter.py [0:0]
def filter(self, doc: Document) -> bool:
    """Keep only the parts of ``doc.text`` whose predicted labels pass the thresholds.

    The text is split into units according to ``self.filter_mode``; every
    unit is scored by ``self.model`` and either kept or dropped. ``doc.text``
    is then overwritten with the concatenation of the kept units. When
    ``self.save_labels_in_metadata`` is set, the mean score of every label
    (averaged over all units, kept or not) is written into ``doc.metadata``
    under the label name returned by the model.

    Args:
        doc: the document to filter; its ``text`` (and possibly ``metadata``)
            is modified in place.

    Returns:
        True if any non-whitespace text remains after filtering.
    """

    def passes_thresholds(score_by_label):
        # A label the model did not return gets a very low sentinel score,
        # so it can never reach a realistic threshold.
        def any_label_reaches(thresholds):
            return any(
                score_by_label.get(f"__label__{name}", -9e9) >= cutoff for name, cutoff in thresholds
            )

        # keep_labels takes precedence: when set, remove_labels is ignored.
        if self.keep_labels:
            return any_label_reaches(self.keep_labels)
        # Nothing to remove (empty or unset) -> every unit passes.
        if not self.remove_labels:
            return True
        return not any_label_reaches(self.remove_labels)

    surviving_parts = []
    scores_per_label = defaultdict(list)
    for part in split_into_parts(doc.text, mode=self.filter_mode):
        # The unit is stripped and its newlines replaced before prediction.
        model_input = part.strip().replace("\n", self.newline_replacement)
        # NOTE(review): k=-1 presumably asks the model for all labels — confirm against fastText docs.
        labels, scores = self.model.predict(model_input, k=-1)

        if self.save_labels_in_metadata:
            for name, value in zip(labels, scores):
                scores_per_label[name].append(value)

        if not passes_thresholds(dict(zip(labels, scores))):
            self.stat_update("removed_span")
            continue

        # The original (unstripped) unit is kept so the text rejoins unchanged.
        surviving_parts.append(part)
        self.stat_update("kept_span")

    doc.text = "".join(surviving_parts)

    if self.save_labels_in_metadata:
        mean_scores = {name: np.mean(values).item() for name, values in scores_per_label.items()}
        doc.metadata.update(mean_scores)

    return bool(doc.text.strip())