#!/usr/bin/env python3
"""
Evaluate a trained model with both the BLEU and chrF metrics.
Kinds:
taskcluster/kinds/evaluate/kind.yml
taskcluster/kinds/evaluate-quantized/kind.yml
taskcluster/kinds/evaluate-teacher-ensemble/kind.yml
Example usage:
$VCS_PATH/pipeline/eval/eval.py
--src en \\
--trg ca \\
--marian_config fetches/final.model.npz.best-chrf.npz.decoder.yml \\
--models fetches/final.model.npz.best-chrf.npz \\
--dataset_prefix fetches/wmt09 \\
--artifacts_prefix artifacts/wmt09 \\
--model_variant gpu \\
--workspace 12000 \\
--gpus 4
Artifacts:
For instance for a artifacts_prefix of: "artifacts/wmt09":
artifacts
├── wmt09.en The source sentences
├── wmt09.ca The target output
├── wmt09.ca.ref The original target sentences
├── wmt09.log The Marian log
├── wmt09.metrics The BLEU and chrF score
└── wmt09.metrics.json The BLEU and chrF score in json format
Fetches:
For instance for a value of: "fetches/wmt09":
fetches
├── wmt09.en.zst
└── wmt09.ca.zst
"""
import argparse
import json
import os
import subprocess
from textwrap import dedent, indent
from typing import Optional

from sacrebleu.metrics.bleu import BLEU, BLEUScore
from sacrebleu.metrics.chrf import CHRF, CHRFScore

from pipeline.common.downloads import decompress_file
from pipeline.common.logging import get_logger

logger = get_logger("eval")

try:
import wandb
from translations_parser.publishers import METRIC_KEYS, WandB
from translations_parser.utils import metric_from_tc_context
from translations_parser.wandb import (
add_wandb_arguments,
get_wandb_publisher,
list_existing_group_logs_metrics,
)
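    # Only report to Weights & Biases when running in Taskcluster, where the proxy
    # is available for fetching secrets.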
WANDB_AVAILABLE = "TASKCLUSTER_PROXY_URL" in os.environ
except ImportError as e:
print(f"Failed to import tracking module: {e}")
    WANDB_AVAILABLE = False


def run_bash_oneliner(command: str):
"""
Runs multi-line bash with comments as a one-line command.
"""
command_dedented = dedent(command)
# Remove comments and whitespace.
    lines = [
        line.strip()
        for line in command_dedented.split("\n")
        if line.strip() and not line.strip().startswith("#")
    ]
command = " \\\n".join(lines)
logger.info("-----------------Running bash in one line--------------")
logger.info(indent(command_dedented, " "))
logger.info("-------------------------------------------------------")
    return subprocess.check_call(command, shell=True)


def main(args_list: Optional[list[str]] = None) -> None:
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter, # Preserves whitespace in the help text.
)
parser.add_argument(
"--artifacts_prefix",
type=str,
help="The location where the translated results will be saved",
)
parser.add_argument(
"--dataset_prefix", type=str, help="The evaluation datasets prefix, used in the form."
)
parser.add_argument("--src", type=str, help='The source language, e.g "en".')
parser.add_argument("--trg", type=str, help='The target language, e.g "ca".')
parser.add_argument("--marian", type=str, help="The path the to marian binaries.")
parser.add_argument("--marian_config", type=str, help="The marian yaml config for the model.")
parser.add_argument(
"--quantized",
action="store_true",
help="Use a quantized model. This requires the browsermt fork of Marian",
)
parser.add_argument(
"--models",
type=str,
help="The Marian model (or models if its an ensemble) to use for translations",
)
parser.add_argument(
"--vocab_src",
required=False,
type=str,
help="The path to a src vocab file (optional)",
)
parser.add_argument(
"--vocab_trg",
required=False,
type=str,
help="The path to a trg vocab file (optional)",
)
parser.add_argument(
"--shortlist",
required=False,
type=str,
help="The path to a lexical shortlist (optional)",
)
parser.add_argument("--workspace", type=str, help="The preallocated MB for the workspace")
parser.add_argument(
"--gpus",
required=False,
type=str,
help="Which GPUs to use (only for the gpu model variant)",
)
parser.add_argument(
"--model_variant", type=str, help="The model variant to use, (gpu, cpu, quantized)"
)
    # Add the Weights & Biases CLI args when the tracking module is available.
if WANDB_AVAILABLE:
add_wandb_arguments(parser)
args = parser.parse_args(args_list)
src = args.src
trg = args.trg
dataset_prefix = args.dataset_prefix
artifacts_prefix = args.artifacts_prefix
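    # Derive the input and output file paths from the dataset and artifacts prefixes.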
artifacts_dir = os.path.dirname(artifacts_prefix)
source_file_compressed = f"{dataset_prefix}.{src}.zst"
source_file = f"{artifacts_prefix}.{src}"
target_file_compressed = f"{dataset_prefix}.{trg}.zst"
target_file = f"{artifacts_prefix}.{trg}"
target_ref_file = f"{artifacts_prefix}.{trg}.ref"
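    # Quote the directory so the path survives bash word splitting in the command below.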
marian_decoder = f'"{args.marian}"/marian-decoder'
marian_log_file = f"{artifacts_prefix}.log"
language_pair = f"{src}-{trg}"
metrics_file = f"{artifacts_prefix}.metrics"
metrics_json = f"{artifacts_prefix}.metrics.json"
# Configure Marian for the different model variants.
marian_extra_args = []
if args.model_variant == "quantized":
marian_extra_args = ["--int8shiftAlphaAll"]
elif args.model_variant == "gpu":
if not args.workspace:
raise Exception("The workspace size was not provided")
        marian_extra_args = [
            "--workspace", args.workspace,
            "--devices", args.gpus,
        ]  # fmt: skip
    elif args.model_variant != "cpu":
raise Exception(f"Unsupported model variant {args.model_variant}")
if args.vocab_src and args.vocab_trg:
marian_extra_args = [*marian_extra_args, "--vocabs", args.vocab_src, args.vocab_trg]
if args.shortlist:
        # No arguments are passed with the shortlist, so Marian's defaults are used.
        # This way it doesn't matter whether the shortlist is binary or text, since
        # the two formats take different arguments:
        #   text shortlist args: firstNum bestNum threshold
        #   binary shortlist (has the arguments embedded) args: bool (check integrity)
        marian_extra_args = [*marian_extra_args, "--shortlist", args.shortlist]
logger.info("The eval script is configured with the following:")
logger.info(f" > artifacts_dir: {artifacts_dir}")
logger.info(f" > source_file_compressed: {source_file_compressed}")
logger.info(f" > source_file: {source_file}")
logger.info(f" > target_file: {target_file}")
logger.info(f" > target_ref_file: {target_ref_file}")
logger.info(f" > marian_decoder: {marian_decoder}")
logger.info(f" > marian_log_file: {marian_log_file}")
logger.info(f" > language_pair: {language_pair}")
logger.info(f" > metrics_file: {metrics_file}")
logger.info(f" > metrics_json: {metrics_json}")
logger.info(f" > marian_extra_args: {marian_extra_args}")
logger.info(f" > gpus: {args.gpus}")
logger.info("Ensure that the artifacts directory exists.")
os.makedirs(artifacts_dir, exist_ok=True)
logger.info("Save the original target sentences to the artifacts")
decompress_file(target_file_compressed, keep_original=False, decompressed_path=target_ref_file)
run_bash_oneliner(
f"""
# Decompress the source file, e.g. $fetches/wmt09.en.zst
zstdmt -dc "{source_file_compressed}"
# Tee the source file into the artifacts directory, e.g. $artifacts/wmt09.en
| tee "{source_file}"
# Take the source and pipe it in to be decoded (translated) by Marian.
| {marian_decoder}
--models {args.models}
--config {args.marian_config}
--quiet
--quiet-translation
--log {marian_log_file}
{" ".join(marian_extra_args)}
        # The translations are "tee"ed out to the artifacts, e.g. $artifacts/wmt09.ca
| tee "{target_file}"
"""
)
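    # Load the reference, translated, and source sentences for computing the metrics.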
with open(target_ref_file, "r") as file:
target_ref_lines = file.readlines()
with open(target_file, "r") as file:
target_lines = file.readlines()
with open(source_file, "r") as file:
source_lines = file.readlines()
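    # Note: sacrebleu selects the BLEU tokenizer based on the target language.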
compute_bleu = BLEU(trg_lang=trg)
compute_chrf = CHRF()
logger.info("Computing the BLEU score.")
bleu_score: BLEUScore = compute_bleu.corpus_score(target_lines, [target_ref_lines])
bleu_details = json.loads(
bleu_score.format(signature=compute_bleu.get_signature().format(), is_json=True)
)
logger.info("Computing the chrF score.")
chrf_score: CHRFScore = compute_chrf.corpus_score(target_lines, [target_ref_lines])
chrf_details = json.loads(
chrf_score.format(signature=compute_chrf.get_signature().format(), is_json=True)
)
    # The default COMET model.
# It should match the model used in https://github.com/mozilla/firefox-translations-models/
comet_model_name = "Unbabel/wmt22-comet-da"
if os.environ.get("COMET_SKIP"):
comet_score = "skipped"
print("COMET_SKIP was set, so the COMET score will not be computed.")
else:
logger.info("Loading COMET")
import comet
# COMET_MODEL_DIR allows tests to place the model in a data directory
comet_checkpoint = comet.download_model(
comet_model_name, saving_directory=os.environ.get("COMET_MODEL_DIR")
)
comet_model = comet.load_from_checkpoint(comet_checkpoint)
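        # COMET scores each (source, translation, reference) triple.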
comet_data = []
for source, target, target_ref in zip(source_lines, target_lines, target_ref_lines):
comet_data.append({"src": source, "mt": target, "ref": target_ref})
        # GPU information comes in the form of a list of numbers, e.g. "0 1 2 3". Split
        # these to get the GPU count, and fall back to the CPU when no GPUs are provided.
        gpu_count = len(args.gpus.split(" ")) if args.gpus else 0
        if os.environ.get("COMET_CPU"):
            gpu_count = 0  # Let tests force COMET to run on the CPU.
comet_mode = "cpu" if gpu_count == 0 else "gpu"
logger.info(f'Computing the COMET score with "{comet_model_name}" using the {comet_mode}')
comet_results = comet_model.predict(comet_data, gpus=gpu_count)
# Reduce the precision.
comet_score = round(comet_results.system_score, 4)
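    # Assemble the metrics, which are written out in both JSON and the older text format.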
metrics = {
"bleu": {
"score": bleu_details["score"],
# Example details:
# {
# "name": "BLEU",
# "score": 0.4,
# "signature": "nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0",
# "verbose_score": "15.6/0.3/0.2/0.1 (BP = 0.823 ratio = 0.837 hyp_len = 180 ref_len = 215)",
# "nrefs": "1",
# "case": "mixed",
# "eff": "no",
# "tok": "13a",
# "smooth": "exp",
# "version": "2.0.0"
# }
"details": bleu_details,
},
"chrf": {
"score": chrf_details["score"],
# Example details:
# {
# "name": "chrF2",
# "score": 0.64,
# "signature": "nrefs:1|case:mixed|eff:yes|nc:6|nw:0|space:no|version:2.0.0",
# "nrefs": "1",
# "case": "mixed",
# "eff": "yes",
# "nc": "6",
# "nw": "0",
# "space": "no",
# "version": "2.0.0"
# }
"details": chrf_details,
},
"comet": {
"score": comet_score,
"details": {
"model": comet_model_name,
"score": comet_score,
},
},
}
logger.info(f"Writing {metrics_json}")
with open(metrics_json, "w") as file:
file.write(json.dumps(metrics, indent=2))
logger.info(f'Writing the metrics in the older "text" format: {metrics_file}')
with open(metrics_file, "w") as file:
file.write(f"{bleu_details['score']}\n" f"{chrf_details['score']}\n" f"{comet_score}\n")
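    # When running in Taskcluster, publish the metrics to Weights & Biases.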
if WANDB_AVAILABLE:
metric = metric_from_tc_context(
chrf=chrf_details["score"], bleu=bleu_details["score"], comet=comet_score
)
run_client = get_wandb_publisher( # noqa
project_name=args.wandb_project,
group_name=args.wandb_group,
run_name=args.wandb_run_name,
taskcluster_secret=args.taskcluster_secret,
artifacts=args.wandb_artifacts,
publication=args.wandb_publication,
)
if run_client is None:
            # W&B publication may be directly disabled through WANDB_PUBLICATION
return
logger.info(f"Publishing metrics to Weight & Biases ({run_client.extra_kwargs})")
run_client.open()
run_client.handle_metrics(metrics=[metric])
run_client.close()
# Publish an extra row on the group_logs summary run
group_logs_client = WandB( # noqa
project=run_client.wandb.project,
group=run_client.wandb.group,
name="group_logs",
suffix=run_client.suffix,
)
logger.info("Adding metric row to the 'group_logs' run")
group_logs_client.open()
# Restore existing metrics data
data = list_existing_group_logs_metrics(group_logs_client.wandb)
data.append(
[
run_client.wandb.group,
run_client.wandb.name,
metric.importer,
metric.dataset,
metric.augmentation,
]
+ [getattr(metric, attr) for attr in METRIC_KEYS]
)
group_logs_client.wandb.log(
{
"metrics": wandb.Table(
columns=[
"Group",
"Model",
"Importer",
"Dataset",
"Augmenation",
*METRIC_KEYS,
],
data=data,
)
}
)
        group_logs_client.close()


if __name__ == "__main__":
main()