evals/translators/opusmt.py (23 lines of code) (raw):

"""Translate lines from stdin with a Helsinki-NLP OPUS-MT (Marian) model.

The language pair is taken from the ``SRC`` and ``TRG`` environment
variables. One translation is written to stdout for every input line.
"""

import os
import sys

import toolz
import torch
from tqdm import tqdm
from transformers import MarianMTModel, MarianTokenizer

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def translate(texts, tokenizer, model, batch_size=10):
    """Translate *texts* in batches and return the decoded translations.

    Args:
        texts: Iterable of source-language strings.
        tokenizer: Tokenizer used to encode the inputs and decode the
            generated token ids (a ``MarianTokenizer`` in this script).
        model: Model whose ``generate`` method produces the translations;
            it is expected to already live on ``device``.
        batch_size: Number of texts sent through the model per call.

    Returns:
        A list of translated strings, in the same order as *texts*.
    """
    results = []
    # Materialize the batches up front so tqdm can display a total.
    batches = list(toolz.partition_all(batch_size, texts))
    for batch in tqdm(batches):
        # truncation=True clips inputs to the tokenizer's model_max_length.
        # NOTE(review): Marian models use fixed-size positional embeddings,
        # so an over-long line would otherwise abort the whole run — confirm
        # that silent truncation is acceptable for this evaluation.
        encoded = tokenizer(
            list(batch),
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        generated_tokens = model.generate(**encoded)
        results.extend(
            tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        )
    return results


def main():
    """Read lines from stdin, translate them and print one result per line.

    Raises:
        KeyError: If the ``SRC`` or ``TRG`` environment variable is not set.
    """
    # Look up the language pair first so a missing variable fails fast,
    # before stdin is consumed or a model is loaded.
    source = os.environ["SRC"]
    target = os.environ["TRG"]
    model_name = f"Helsinki-NLP/opus-mt-{source}-{target}"

    texts = [line.strip() for line in sys.stdin]

    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name).to(device)

    translations = translate(texts, tokenizer, model)
    # One newline-terminated line per translation: identical to the previous
    # join-based output for non-empty input, and emits nothing (instead of a
    # stray blank line) when there was no input.
    sys.stdout.writelines(f"{translation}\n" for translation in translations)


if __name__ == "__main__":
    main()