pipeline/bicleaner/download_pack.py (81 lines of code) (raw):
#!/usr/bin/env python3
"""
Downloads bicleaner-ai model for a lanuage pair.
Fallbacks to the multilingual model if the lanuage pair is not supported.
Example:
python download_pack.py \
--src=en \
--trg=ru \
artifacts/bicleaner-model-en-ru.zst
"""
import argparse
import os
import shutil
import subprocess
import tarfile
import tempfile
from typing import Optional
from pipeline.common.downloads import compress_file
from pipeline.common.logging import get_logger
logger = get_logger(__file__)
# bicleaner-ai-download downloads the latest models from Hugging Face / Github
# If a new model is released and you want to invalidate Taskcluster caches,
# change this file since it is a part of the cache digest
# The last model was added to https://huggingface.co/bitextor on Mar 11, 2024
def _run_download(src: str, trg: str, dir: str) -> subprocess.CompletedProcess:
# use large multilingual models
model_type = "full-large" if trg == "xx" else "full"
return subprocess.run(
["bicleaner-ai-download", src, trg, model_type, dir], capture_output=True, check=False
)
def _compress_dir(dir_path: str) -> str:
logger.info(f"Compressing {dir_path}")
tarball_path = f"{dir_path}.tar"
with tarfile.open(tarball_path, "w") as tar:
tar.add(dir_path, arcname=os.path.basename(dir_path))
compressed_path = str(compress_file(tarball_path))
return compressed_path
def check_result(result: subprocess.CompletedProcess):
"""Checks the return code, and outputs the stdout and stderr if it fails."""
if result.returncode != 0:
print(result.stdout)
print(result.stderr)
result.check_returncode()
def download(src: str, trg: str, output_path: str) -> None:
tmp_dir = os.path.join(tempfile.gettempdir(), f"bicleaner-ai-{src}-{trg}")
if os.path.exists(tmp_dir):
# A previous download attempt failed, remove the temporary files.
shutil.rmtree(tmp_dir)
os.mkdir(tmp_dir)
# Attempt to download a model.
# 1: src-trg
# 2: trg-src
# 3: multilingual model
logger.info(f"Attempt 1 of 3: Downloading a model for {src}-{trg}")
result = _run_download(src, trg, tmp_dir)
meta_path = os.path.join(tmp_dir, "metadata.yaml")
if os.path.exists(meta_path):
check_result(result)
logger.info(f"The model for {src}-{trg} is downloaded")
else:
src, trg = trg, src
logger.info(f"Attempt 2 of 3. Downloading a model for {src}-{trg}")
result = _run_download(src, trg, tmp_dir)
if os.path.exists(meta_path):
check_result(result)
print(f"The model for {src}-{trg} is downloaded")
else:
logger.info("Attempt 3 of 3. Downloading the multilingual model en-xx")
src = "en"
trg = "xx"
result = _run_download(src, trg, tmp_dir)
if not os.path.exists(meta_path):
check_result(result)
raise Exception("Could not download the multilingual model")
print(f"The model for {src}-{trg} is downloaded")
pack_path = tmp_dir
logger.info("Compress the downloaded pack.")
pack_path = _compress_dir(pack_path)
# Move to the expected path
logger.info(f"Moving {pack_path} to {output_path}")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
shutil.move(pack_path, output_path)
logger.info("Done")
def main(args: Optional[list[str]] = None) -> None:
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter, # Preserves whitespace in the help text.
)
parser.add_argument("--src", type=str, help="Source language code")
parser.add_argument("--trg", type=str, help="Target language code")
parser.add_argument(
"output_path",
type=str,
help="Full output file or directory path for example artifacts/en-pt.zst",
)
parsed_args = parser.parse_args(args)
download(
src=parsed_args.src,
trg=parsed_args.trg,
output_path=parsed_args.output_path,
)
if __name__ == "__main__":
main()