evals/translators/nllb.py:

"""Translate lines read from stdin with the NLLB-200 distilled 600M model.

The source language comes from the SRC environment variable and is passed to the
tokenizer as its src_lang; the target language comes from TRG and is resolved to
an NLLB language code, which is forced as the first generated token.
"""
import os
import sys

import toolz
import torch
from mtdata import iso
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Explicit NLLB codes for target languages that the ISO 639-3 prefix match in
# translate() does not resolve unambiguously.
LANG_CODE_MAP = {
    "ar": "arb_Arab",
    "fa": "pes_Arab",
    "lv": "lvs_Latn",
    "zh": "zho_Hans",
}


def translate(texts, tokenizer, model, target):
    results = []
    if target in LANG_CODE_MAP:
        lang_code = LANG_CODE_MAP[target]
    else:
        # Fall back to matching the tokenizer's language tokens by ISO 639-3 prefix.
        lang_code = None
        for lang in tokenizer.additional_special_tokens:
            if lang.startswith(iso.iso3_code(target)):
                assert (
                    lang_code is None
                ), "Multiple NLLB language codes found for the same language ID, need to disambiguate!"
                lang_code = lang
    assert lang_code is not None, f"Lang code for {target} was not found"
    # NLLB marks the target language by forcing its token as the first generated token.
    forced_bos_token_id = tokenizer.lang_code_to_id[lang_code]
    # Translate in batches of 10 sentences.
    for partition in tqdm(list(toolz.partition_all(10, texts))):
        tokenized_src = tokenizer(partition, return_tensors="pt", padding=True).to(device)
        generated_tokens = model.generate(**tokenized_src, forced_bos_token_id=forced_bos_token_id)
        results += tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return results


if __name__ == "__main__":
    # Input sentences arrive on stdin, one per line; languages come from the SRC/TRG environment variables.
    texts = [line.strip() for line in sys.stdin]
    source = os.environ["SRC"]
    target = os.environ["TRG"]
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang=source)
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").to(device)
    translations = translate(texts, tokenizer, model, target)
    sys.stdout.write("\n".join(translations))
    sys.stdout.write("\n")