# augmentation/filtering.py (27 lines of code) (raw):
from nlpaug import Augmenter
from typing import Callable, TypeVar
# Generic type of the value returned by ``metric_fn`` and consumed by
# ``metric_acceptor`` in ``FilterAugmented``.
T = TypeVar("T")
class FilterAugmented:
    """Wrap an augmenter and keep only the variants that pass a metric check.

    Every variant produced by the wrapped augmenter is scored against the
    original text with ``metric_fn``; the score is then passed to
    ``metric_acceptor`` and the variant is kept only when that call returns
    a truthy value.
    """

    def __init__(
        self,
        augmenter: Augmenter,
        metric_fn: Callable[[str, str], T],
        metric_acceptor: Callable[[T], bool]
    ):
        """Store the collaborators used by :meth:`augment`.

        Args:
            augmenter: Object whose ``augment(text, *args, **kwargs)`` method
                produces the candidate variants.
            metric_fn: Called as ``metric_fn(original_text, variant)``; its
                result is forwarded to ``metric_acceptor``.
            metric_acceptor: Called with the metric value; a truthy result
                keeps the variant.
        """
        self.metric_acceptor = metric_acceptor
        self.metric_fn = metric_fn
        self.augmenter = augmenter

    def augment(self, text: str, *args, **kwargs):
        """Augment ``text`` and return the list of accepted variants.

        Extra positional and keyword arguments are forwarded unchanged to the
        wrapped augmenter's ``augment`` method.

        Returns:
            A list with the variants, in the order the augmenter produced
            them, for which ``metric_acceptor(metric_fn(text, variant))`` is
            truthy.
        """
        candidates = self.augmenter.augment(text, *args, **kwargs)
        # A bare string result is treated as a single variant.
        if isinstance(candidates, str):
            candidates = [candidates]

        accepted = []
        for candidate in candidates:
            score = self.metric_fn(text, candidate)
            if self.metric_acceptor(score):
                accepted.append(candidate)
        return accepted
class ThresholdAcceptor:
    """Callable predicate that accepts values strictly between two bounds.

    Both bounds are exclusive: a value equal to ``low`` or ``high`` is
    rejected.
    """

    def __init__(self, low, high):
        """Remember the exclusive lower (``low``) and upper (``high``) bounds."""
        self.low, self.high = low, high

    def __call__(self, value):
        """Return the result of the check ``low < value < high``."""
        above_low = self.low < value
        # The upper bound is only compared when the lower check succeeds,
        # exactly as in a chained comparison.
        return above_low and value < self.high