augmentation/augmenters.py (49 lines of code) (raw):
from nlpaug.augmenter.word import BackTranslationAug
from utils import deepflatten_sequence
from typing import List
class MultiVariantBackTranslationAug(BackTranslationAug):
    """Back-translation augmenter that returns several paraphrases per input.

    Extends ``BackTranslationAug`` so that both translation legs (source ->
    pivot via ``self.model.from_model`` and pivot -> source via
    ``self.model.to_model``) keep up to ``n_predictions`` hypotheses instead
    of a single one. Every pivot hypothesis is translated back, and the
    resulting paraphrases are flattened and de-duplicated.

    Args:
        from_model_name, to_model_name, from_model_checkpt, to_model_checkpt,
        tokenizer, bpe, is_load_from_github, name, device, force_reload,
        verbose: Forwarded unchanged to ``BackTranslationAug.__init__``.
        n_predictions: Beam size and maximum number of hypotheses kept per
            sentence on each translation leg.
        generation_kwargs: Extra keyword arguments forwarded to the models'
            ``generate`` call on both legs. Must not contain ``beam`` or
            ``n_samples``; those are set from ``n_predictions``, and passing
            them again raises ``TypeError`` (duplicate keyword argument).
    """

    def __init__(self, from_model_name='transformer.wmt19.en-de', to_model_name='transformer.wmt19.de-en',
                 from_model_checkpt='model1.pt', to_model_checkpt='model1.pt', tokenizer='moses', bpe='fastbpe',
                 is_load_from_github=True, name='BackTranslationAug', device='cpu', force_reload=False, verbose=0,
                 n_predictions: int = 1, generation_kwargs=None):
        # Forward the caller's arguments. Previously every one of them was
        # replaced by a hard-coded literal, so e.g. device='cuda' was ignored.
        super().__init__(
            from_model_name=from_model_name,
            to_model_name=to_model_name,
            from_model_checkpt=from_model_checkpt,
            to_model_checkpt=to_model_checkpt,
            tokenizer=tokenizer,
            bpe=bpe,
            is_load_from_github=is_load_from_github,
            name=name,
            device=device,
            force_reload=force_reload,
            verbose=verbose,
        )
        self.n_predictions = n_predictions
        self.generation_kwargs = generation_kwargs or {}

    def substitute(self, text):
        """Back-translate ``text`` into multiple paraphrases.

        Args:
            text: A single sentence (``str``) or a list of sentences.

        Returns:
            One list per input sentence, holding the unique back-translated
            paraphrases in first-seen order.
        """
        # Forward leg -> List[List[str]]: pivot hypotheses per input sentence.
        translated = self.sample_n(
            self.model.from_model,
            text,
            beam=self.n_predictions,
            n_samples=self.n_predictions,
            **self.generation_kwargs
        )
        result = []
        for pivot_hypotheses in translated:
            # Backward leg -> List[List[str]]: one list per pivot hypothesis.
            nested = self.sample_n(
                self.model.to_model,
                pivot_hypotheses,
                beam=self.n_predictions,
                n_samples=self.n_predictions,
                **self.generation_kwargs
            )
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # output is reproducible; set() order of str values changes with
            # hash randomization (PYTHONHASHSEED) between processes.
            result.append(list(dict.fromkeys(deepflatten_sequence(nested))))
        return result

    @staticmethod
    def sample_n(
        model, sentences: List[str], beam: int = 1, verbose: bool = False, n_samples: int = 1, **kwargs
    ) -> List[List[str]]:
        """Translate sentences, keeping up to ``n_samples`` hypotheses for each.

        Encodes every sentence with ``model.encode``, runs one batched
        ``model.generate`` call and decodes the first
        ``min(n_samples, len(hypos))`` hypotheses per sentence.

        Args:
            model: Translation model exposing ``encode``, ``generate`` and
                ``decode`` (presumably a fairseq hub interface; confirm).
            sentences: List of sentences. A single ``str`` is also accepted
                and treated as a one-element list.
            beam: Beam size, passed positionally to ``model.generate``.
            verbose: Passed positionally to ``model.generate``.
            n_samples: Maximum number of hypotheses kept per sentence.
            **kwargs: Extra generation options forwarded to ``model.generate``.

        Returns:
            One list of decoded hypotheses per input sentence, in the order
            ``model.generate`` returned them.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        tokenized_sentences = [model.encode(sentence) for sentence in sentences]
        batched_hypos = model.generate(tokenized_sentences, beam, verbose, **kwargs)
        # min() guards against n_samples exceeding the hypotheses returned.
        return [
            [model.decode(hypos[i]["tokens"]) for i in range(min(n_samples, len(hypos)))]
            for hypos in batched_hypos
        ]