in augmentation/backtranslate.py [0:0]
# Maps "<source>-<target>" language pairs to the translation model names handed to
# BackTranslationAug. Only pairs listed here may appear as consecutive languages in
# a chain.
# NOTE(review): there is no "fr-en" entry, so "en-fr" can never complete a round
# trip back to English; confirm whether a fr-en model should be added.
_PAIR_TO_MODEL = {
    "en-fr": "transformer.wmt14.en-fr",
    "en-de": "transformer.wmt19.en-de",
    "de-en": "transformer.wmt19.de-en",
    "en-ru": "transformer.wmt19.en-ru",
    "ru-en": "transformer.wmt19.ru-en",
}


def _model_for_pair(source: str, target: str) -> str:
    """Return the configured translation model name for ``source`` -> ``target``.

    Raises:
        KeyError: if no model is configured for the language pair.
    """
    key = f"{source}-{target}"
    try:
        return _PAIR_TO_MODEL[key]
    except KeyError:
        supported = ", ".join(sorted(_PAIR_TO_MODEL))
        raise KeyError(
            f"No translation model for language pair {key!r}; supported pairs: {supported}"
        ) from None


def build_chain_translation_augmenter(language_chain: List[str], device: str) -> Sequential:
    """Build a sequential back-translation augmenter from a chain of languages.

    The chain is read as consecutive round trips that share their endpoints, e.g.
    ``["en", "de", "en", "ru", "en"]`` yields two back-translation steps:
    en -> de -> en, then en -> ru -> en. Each round trip becomes one
    ``BackTranslationAug``, and all of them are wrapped in a ``Sequential``.

    Args:
        language_chain: Language codes. Must have an odd length of at least three
            so that it decomposes into complete round trips.
        device: Passed unchanged to every ``BackTranslationAug``.

    Returns:
        A ``Sequential`` holding one ``BackTranslationAug`` per round trip, in
        chain order.

    Raises:
        ValueError: if the chain has fewer than three languages or an even length.
        KeyError: if a consecutive language pair has no configured model.
    """
    if len(language_chain) < 3:
        raise ValueError("Can't backtranslate with fewer than three languages in a chain")
    if len(language_chain) % 2 == 0:
        raise ValueError(
            "Language chain must have an odd length so it splits into complete "
            f"round trips, got {len(language_chain)}: {language_chain}"
        )

    augmenters = []
    # Step by two: each round trip begins where the previous one ended. A step of
    # one would also build windows such as (de, en, ru), whose from-model (de-en)
    # does not match the language of the text produced by the preceding step.
    for i in range(0, len(language_chain) - 2, 2):
        source, pivot, target = language_chain[i:i + 3]
        augmenters.append(
            BackTranslationAug(
                from_model_name=_model_for_pair(source, pivot),
                to_model_name=_model_for_pair(pivot, target),
                device=device,
            )
        )
    return Sequential(augmenters)