utils/run_model.py (175 lines of code) (raw):

#!/usr/bin/env python3
"""
Run Marian models locally. When a --task_id is given, the model is downloaded to the
./data directory, and immediately run. This process loads the marian-server binary in the
background, and communicates to it via a websocket.

Usage:

task run-model -- --task_id fpJkLJRaRAqTxgG0ARwR1w
"""

import argparse
import atexit
import os
import re
import subprocess
import sys
import time
from typing import Any, Optional

from websocket import WebSocket, create_connection

import taskcluster
from pipeline.common.downloads import stream_download_to_file

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
TC_MOZILLA = "https://firefox-ci-tc.services.mozilla.com"
DATA_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../data"))
# Location of the marian-server binary that main() launches.
MARIAN_SERVER_BIN = "/builds/worker/tools/marian-dev/build/marian-server"


def get_language_pair_from_task_name(task_name):
    """
    Extract the "{src}-{trg}" language pair from a training task name.

    The name must contain "train-", followed by one word and a dash (for example
    "student-"), followed by two two-character language codes separated by a dash.

    Returns the language pair string, or None when the name does not match.
    """
    pattern = r"""
        \btrain-               # "train-"
        \w+-                   # "backwards", "student", etc.
        (?P<language_pair>     # Start a capture group 'language_pair'
            \w{2}-\w{2}        # Match the actual language pair
        )
    """
    match = re.search(pattern, task_name, re.VERBOSE)
    return match.group("language_pair") if match else None


def _find_artifact(artifacts: list, suffix: str) -> Optional[dict]:
    """Return the first artifact whose name ends with the suffix, or None."""
    # The explicit default keeps a missing artifact from raising StopIteration, so the
    # caller can report which file is missing.
    return next((a for a in artifacts if a["name"].endswith(suffix)), None)


def download_model_from_task(output: str, task_id: str) -> str:
    """
    Downloads to the output directory in a subfolder {src}-{trg}-{task_id}

    The decoder config, the model, and the vocab are saved as "decoder.yml",
    "model.npz", and "vocab.spm".

    Returns the path to the folder containing the downloaded files.

    Raises an Exception when the task is not completed, when the language pair can not
    be determined from the task name, or when one of the artifacts is missing.
    """
    options = {"rootUrl": TC_MOZILLA}
    queue: Any = taskcluster.Queue(options=options)

    task = queue.task(task_id)
    status = queue.status(task_id)
    if status["status"]["state"] != "completed":
        raise Exception("The task was not completed")

    task_name = task["metadata"]["name"]
    language_pair = get_language_pair_from_task_name(task_name)
    if not language_pair:
        raise Exception(f'Could not find the language pair for the task "{task_name}"')

    print(f"Downloading from task {task_name}")

    artifacts = queue.listLatestArtifacts(task_id)["artifacts"]
    for artifact in artifacts:
        print(artifact["name"])

    decoder = _find_artifact(artifacts, "/final.model.npz.best-chrf.npz.decoder.yml")
    model = _find_artifact(artifacts, "/final.model.npz.best-chrf.npz")
    vocab = _find_artifact(artifacts, "/vocab.spm")

    if not decoder:
        raise Exception("Could not find the decoder in artifacts")
    if not model:
        raise Exception("Could not find the model in artifacts")
    if not vocab:
        raise Exception("Could not find the vocab in artifacts")

    model_path = os.path.join(output, f"{language_pair}-{task_id}")
    print(model_path)
    os.makedirs(model_path, exist_ok=True)

    print(f'Downloading models from "{task_name}": ' f"{TC_MOZILLA}/tasks/{task_id}")

    downloads = [
        (decoder, "decoder.yml"),
        (model, "model.npz"),
        (vocab, "vocab.spm"),
    ]
    for artifact, filename in downloads:
        stream_download_to_file(
            queue.buildUrl("getLatestArtifact", task_id, artifact["name"]),
            os.path.join(model_path, filename),
        )

    print(f"Model files are available at: {model_path}")
    return model_path


def find_model(output: str, task_id: str) -> Optional[str]:
    """
    Look in the output directory for an entry whose name contains "-{task_id}".

    Returns the path to the first matching entry, or None when there is none.
    """
    for dir_name in os.listdir(output):
        if f"-{task_id}" in dir_name:
            return os.path.join(output, dir_name)
    return None


def connect_to_ws(port: int) -> WebSocket:
    """
    Attempts to connect to a websocket with multiple attempts.

    The connection is retried once per second, up to 100 times. The process exits with
    status 1 when no connection could be made.
    """
    max_retries = 100
    retry_delay_sec = 1
    uri = f"ws://localhost:{port}/translate"

    # Flush so the progress output shows up right away rather than at the final newline.
    print(f"Attempting to connect to {uri}", end="", flush=True)

    ws = None
    for _ in range(max_retries):
        try:
            ws = create_connection(uri)
            break
        except Exception:
            # The server may not be listening yet, wait and attempt to reconnect.
            print(".", end="", flush=True)
            time.sleep(retry_delay_sec)
    print()

    if ws is None:
        print("Failed to connect to the Marian server.")
        sys.exit(1)

    print("Connected to Marian server.\n")
    return ws


def translate_over_websocket(port: int):
    """
    Opens a websocket connection to the Marian server, and accepts translation input
    from stdin.

    The loop runs until it is interrupted (Ctrl-C), stdin is closed, or communicating
    with the server fails. The websocket is closed in every case.
    """
    ws = connect_to_ws(port)
    try:
        while True:
            print("Enter text to translate:")
            line = input("> ")
            ws.send(line.encode("utf-8"))
            translation = ws.recv()
            print("\nTranslation:")
            print(">", translation)
    except (KeyboardInterrupt, EOFError):
        # The user asked to quit, this is not an error.
        pass
    except Exception as e:
        print(f"Error communicating with Marian server: {e}")
    finally:
        ws.close()


def _parse_bool_flag(value: str) -> bool:
    """Parse a command line boolean such as "true", "False", "1", or "no"."""
    normalized = value.strip().lower()
    if normalized in {"1", "true", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main() -> None:
    """
    Parse the command line arguments, make sure a model is available locally, start the
    marian-server, and translate lines from stdin until the user quits.

    When the IS_DOCKER environment variable is not set, the same command is re-run
    through "task docker-run" instead.
    """
    if not os.environ.get("IS_DOCKER"):
        # Re-run the command in docker if it wasn't started.
        args = sys.argv[1:]
        subprocess.check_call(["task", "docker-run", "--", "task", "run-model", "--", *args])
        return

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_id",
        type=str,
        help="The task ID that contains model artifacts",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Path to a local folder containing a model",
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Where to save the models",
        default=os.path.abspath(os.path.join(DATA_DIR, "models")),
    )
    parser.add_argument(
        "--port", type=int, help="The port that the Marian server listens over", default=8886
    )
    # The flag works on its own (--output_marian), and an explicit value is still
    # accepted (--output_marian True). Without a boolean parser any value, including
    # "False", would be a truthy string.
    parser.add_argument(
        "--output_marian",
        nargs="?",
        const=True,
        default=False,
        type=_parse_bool_flag,
        help="Include the output from the Marian server.",
    )

    args = parser.parse_args()

    # The parent data directory may not exist yet, so create the full path.
    os.makedirs(args.output, exist_ok=True)

    # Ensure the model is downloaded and we can get the model path. Re-use an existing
    # model if it is already downloaded.
    model_path = None
    if args.model:
        model_path = args.model
    elif args.task_id:
        model_path = find_model(args.output, args.task_id)
        if model_path:
            print(f"Model with task ID {args.task_id} has been downloaded at {model_path}")
        else:
            model_path = download_model_from_task(args.output, args.task_id)
    else:
        raise Exception("Provide either a --task_id or a --model")

    decoder = os.path.join(model_path, "decoder.yml")
    model = os.path.join(model_path, "model.npz")
    vocab = os.path.join(model_path, "vocab.spm")

    if not os.path.exists(decoder):
        raise Exception(f"Decoder was not found: {decoder}")
    if not os.path.exists(model):
        raise Exception(f"Model was not found: {model}")
    if not os.path.exists(vocab):
        raise Exception(f"Vocab was not found: {vocab}")

    if args.output_marian:
        # None lets the server inherit this process's stdout and stderr.
        stdout = None
        stderr = None
    else:
        stdout = subprocess.DEVNULL
        stderr = subprocess.DEVNULL

    # Pass the arguments as a list rather than through a shell: paths containing spaces
    # or shell characters stay intact, and the Popen handle refers to the marian-server
    # process itself, so terminating it on exit stops the server.
    command = [
        MARIAN_SERVER_BIN,
        "--config",
        decoder,
        "--models",
        model,
        "--vocabs",
        vocab,
        vocab,
        "--port",
        str(args.port),
    ]

    marian_server = subprocess.Popen(command, stdout=stdout, stderr=stderr)
    atexit.register(exit_handler, marian_server)

    translate_over_websocket(args.port)


def exit_handler(marian_server):
    """Stop the Marian server process if it is still running."""
    if marian_server.poll() is not None:
        # The process already exited, there is nothing to stop.
        return

    marian_server.terminate()
    try:
        marian_server.wait(timeout=5)
    except subprocess.TimeoutExpired:
        # The server ignored the terminate request, force it to stop.
        marian_server.kill()


if __name__ == "__main__":
    main()