def build_chain_translation_augmenter()

in augmentation/backtranslate.py [0:0]


# Translation model names keyed by "<source>-<target>" language pair.
# (These look like fairseq torch.hub model identifiers — TODO confirm against
# the BackTranslationAug version in use.)
# NOTE: "en-fr" has no "fr-en" counterpart, so any round trip through French
# fails when the return leg is looked up.
_PAIR_TO_MODEL = {
    "en-fr": "transformer.wmt14.en-fr",
    "en-de": "transformer.wmt19.en-de",
    "de-en": "transformer.wmt19.de-en",
    "en-ru": "transformer.wmt19.en-ru",
    "ru-en": "transformer.wmt19.ru-en",
}


def _model_for_pair(source: str, target: str) -> str:
    """Return the model name for translating ``source`` -> ``target``.

    Raises:
        KeyError: if the pair has no entry in ``_PAIR_TO_MODEL``.
    """
    key = f"{source}-{target}"
    try:
        return _PAIR_TO_MODEL[key]
    except KeyError:
        raise KeyError(
            f"No translation model for pair {key!r}; "
            f"supported pairs: {sorted(_PAIR_TO_MODEL)}"
        ) from None


def build_chain_translation_augmenter(language_chain: List[str], device: str) -> Sequential:
    """Build a ``Sequential`` of back-translation augmenters from a language chain.

    The chain is read as a start language followed by consecutive
    (pivot, return-language) steps. For example,
    ``["en", "de", "en", "ru", "en"]`` yields two augmenters, en->de->en
    and then en->ru->en. Each step ends in the language the next step
    starts from. So, assuming ``Sequential`` feeds each augmenter's output
    into the next, the input language stays consistent along the chain.

    Args:
        language_chain: Language codes. Must have an odd length of at least 3.
        device: Passed unchanged to every ``BackTranslationAug``.

    Returns:
        A ``Sequential`` wrapping one ``BackTranslationAug`` per step, in
        chain order.

    Raises:
        ValueError: If the chain has fewer than 3 entries or an even number
            of entries.
        KeyError: If a consecutive language pair has no known model.
    """
    if len(language_chain) < 3:
        raise ValueError("Can't backtranslate with fewer than three languages in a chain")
    if len(language_chain) % 2 == 0:
        # Every round trip consumes a (pivot, return-language) pair after the
        # start language, so a valid chain always has odd length.
        raise ValueError(
            "Language chain must have an odd number of entries (a start language "
            f"followed by pivot/return pairs), got {len(language_chain)}"
        )

    # Windows advance by 2, not 1. Each step's return language is the next
    # step's source. Overlapping windows would feed a step's output into a
    # model expecting a different source language.
    steps = zip(language_chain[::2], language_chain[1::2], language_chain[2::2])

    # Resolve every model name before constructing any augmenter. An
    # unsupported pair then fails fast instead of after earlier augmenters
    # were already built.
    model_pairs = [
        (_model_for_pair(source, pivot), _model_for_pair(pivot, target))
        for source, pivot, target in steps
    ]

    augmenters = [
        BackTranslationAug(from_model_name=from_model_name,
                           to_model_name=to_model_name,
                           device=device)
        for from_model_name, to_model_name in model_pairs
    ]
    return Sequential(augmenters)