in utils/run_model.py [0:0]
def main() -> None:
    """Serve a Marian model through a local marian-server and translate over a websocket.

    When the ``IS_DOCKER`` environment variable is unset or empty, the same
    command line is re-run inside Docker via ``task docker-run -- task run-model``
    and this function returns without doing anything else.

    Inside Docker, the model folder comes from ``--model`` (a local folder) or
    is resolved from ``--task_id``. With ``--task_id``, ``find_model`` is asked
    for an existing copy under ``--output``; if it returns nothing, the model is
    fetched with ``download_model_from_task``. The folder must contain
    ``decoder.yml``, ``model.npz`` and ``vocab.spm``. A ``marian-server`` process
    is then started on ``--port``, ``exit_handler`` is registered with ``atexit``
    for that process, and ``translate_over_websocket`` is called with the port.

    ``--output_marian`` may be given bare (enables it) or with an explicit value
    such as ``true``/``false``/``1``/``0``; unrecognised values are rejected by
    argparse.

    Raises:
        Exception: if neither ``--task_id`` nor ``--model`` is given, or if one
            of the required model files is missing from the model folder.
        subprocess.CalledProcessError: if re-running the command in Docker fails.
    """
    if not os.environ.get("IS_DOCKER"):
        # Re-run the command in docker if it wasn't started there.
        args = sys.argv[1:]
        subprocess.check_call(["task", "docker-run", "--", "task", "run-model", "--", *args])
        return

    def parse_bool(value: str) -> bool:
        """Parse a command-line boolean such as "true"/"false", "yes"/"no" or "1"/"0"."""
        normalized = value.strip().lower()
        if normalized in ("1", "true", "yes", "y", "on"):
            return True
        if normalized in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task_id",
        type=str,
        help="The task ID that contains model artifacts",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Path to a local folder containing a model",
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Where to save the models",
        default=os.path.abspath(os.path.join(DATA_DIR, "models")),
    )
    parser.add_argument(
        "--port", type=int, help="The port that the Marian server listens over", default=8886
    )
    # Previously this option had no type/action, so any non-empty value
    # (including the string "False") switched it on. It now parses a real
    # boolean and can also be passed as a bare flag.
    parser.add_argument(
        "--output_marian",
        type=parse_bool,
        nargs="?",
        const=True,
        default=False,
        help="Include the output from the Marian server.",
    )
    args = parser.parse_args()

    # makedirs also creates missing parents and tolerates an existing folder,
    # avoiding the check-then-create race of exists() + mkdir().
    os.makedirs(args.output, exist_ok=True)

    # Ensure the model is downloaded and we can get the model path. Re-use an
    # existing model if it is already downloaded.
    if args.model:
        model_path = args.model
    elif args.task_id:
        model_path = find_model(args.output, args.task_id)
        if model_path:
            print(f"Model with task ID {args.task_id} has been downloaded at {model_path}")
        else:
            model_path = download_model_from_task(args.output, args.task_id)
    else:
        raise Exception("Provide either a --task_id or a --model")

    decoder = os.path.join(model_path, "decoder.yml")
    model = os.path.join(model_path, "model.npz")
    vocab = os.path.join(model_path, "vocab.spm")
    for label, path in (("Decoder", decoder), ("Model", model), ("Vocab", vocab)):
        if not os.path.exists(path):
            raise Exception(f"{label} was not found: {path}")

    # None inherits this process's stdout/stderr; DEVNULL silences the server.
    server_output = None if args.output_marian else subprocess.DEVNULL

    # The same vocab file is passed twice (presumably one for the source side
    # and one for the target side — confirm against the Marian config).
    command = [
        "/builds/worker/tools/marian-dev/build/marian-server",
        "--config", decoder,
        "--models", model,
        "--vocabs", vocab, vocab,
        "--port", str(args.port),
    ]
    # Launch without a shell so paths containing spaces or shell metacharacters
    # are passed through intact, and so the Popen handle given to exit_handler
    # is marian-server itself rather than an intermediate /bin/sh.
    marian_server = subprocess.Popen(command, stdout=server_output, stderr=server_output)
    atexit.register(exit_handler, marian_server)
    translate_over_websocket(args.port)