tracking/translations_parser/wandb.py (150 lines of code) (raw):

"""Helpers to configure and publish training runs to Weights & Biases (W&B)."""

import json
import os
from pathlib import Path
from typing import List

import wandb
import taskcluster

from translations_parser.parser import logger
from translations_parser.publishers import WandB
from translations_parser.utils import build_task_name, suffix_from_group


def add_wandb_arguments(parser):
    """
    Register the Weights & Biases publication options on ``parser``.

    ``parser`` must expose an argparse-compatible ``add_argument`` method.
    The defaults of ``--wandb-publication`` and ``--taskcluster-secret`` are
    read from the ``WANDB_PUBLICATION`` and ``TASKCLUSTER_SECRET`` environment
    variables at the time this function is called.
    """
    parser.add_argument(
        "--wandb-project",
        help="Publish the training run to a Weight & Biases project.",
        default=None,
    )
    parser.add_argument(
        "--wandb-artifacts",
        help="Directory containing training artifacts to publish on Weight & Biases.",
        type=Path,
        default=None,
    )
    parser.add_argument(
        "--wandb-group",
        help="Add the training run to a Weight & Biases group e.g. by language pair or experiment.",
        default=None,
    )
    parser.add_argument(
        "--wandb-run-name",
        help="Use a custom name for the Weight & Biases run.",
        default=None,
    )
    parser.add_argument(
        "--wandb-publication",
        action="store_true",
        help="Trigger publication on Weight & Biases. Disabled by default. Can be set though env variable WANDB_PUBLICATION=true|false",
        # Case-insensitive: only the value "true" enables publication by default.
        default=os.environ.get("WANDB_PUBLICATION", "false").lower() == "true",
    )
    parser.add_argument(
        "--taskcluster-secret",
        help="Taskcluster secret name used to store the Weight & Biases secret API Key.",
        type=str,
        default=os.environ.get("TASKCLUSTER_SECRET"),
    )
    parser.add_argument(
        "--tags",
        help="List of tags to use on Weight & Biases publication",
        type=str,
        default=["taskcluster"],
        nargs="+",
    )


def get_wandb_token(secret_name):
    """
    Retrieve the Weights & Biases API token from a Taskcluster secret.

    The secret is read through the Taskcluster proxy, so the
    ``TASKCLUSTER_PROXY_URL`` environment variable must be set. The token is
    read from ``["secret"]["token"]`` of the secret payload.

    Raises:
        KeyError: if ``TASKCLUSTER_PROXY_URL`` is not set.
        Exception: if fetching the secret fails or the payload has no token.
    """
    secrets = taskcluster.Secrets({"rootUrl": os.environ["TASKCLUSTER_PROXY_URL"]})
    try:
        wandb_secret = secrets.get(secret_name)
        return wandb_secret["secret"]["token"]
    except Exception as e:
        # Chain the original error so the underlying cause stays in the traceback.
        raise Exception(
            f"Weight & Biases secret API Key retrieved from Taskcluster is malformed: {e}"
        ) from e


def get_wandb_names() -> tuple[str, str, str, str]:
    """
    Find the various names needed to publish on Weights & Biases using the
    Taskcluster task & group payloads.

    Returns:
        A ``(project, group, run, task_group_id)`` tuple where:
        - project is ``"ci"`` for CI experiments, otherwise ``"<src>-<trg>"``;
        - group is ``"<experiment name>_<task group id>"``;
        - run is the task name computed by ``build_task_name``;
        - task_group_id is the ``taskGroupId`` of the current task.

    Raises:
        Exception: when ``TASK_ID`` is not set (i.e. not running in Taskcluster).
        KeyError: when ``TASKCLUSTER_PROXY_URL`` is not set.
    """
    task_id = os.environ.get("TASK_ID")
    if not task_id:
        raise Exception("Weight & Biases name detection can only run in taskcluster")

    # Load task & group definition
    # CI task groups do not expose any configuration, so we must use default values
    queue = taskcluster.Queue({"rootUrl": os.environ["TASKCLUSTER_PROXY_URL"]})
    task = queue.task(task_id)
    _, task_name = build_task_name(task)
    group_id = task["taskGroupId"]
    # NOTE(review): this fetches the task whose ID equals the group ID
    # (presumably the decision/action task that created the group).
    task_group = queue.task(group_id)
    config = task_group.get("extra", {}).get("action", {}).get("context", {}).get("input")
    if config is None:
        logger.warning(
            f"Experiment configuration missing on {group_id} @ extra/action/context/input, fallback to CI values"
        )
        experiment = {
            "src": "ru",
            "trg": "en",
            "name": "ci",
        }
    else:
        experiment = config["experiment"]

    # Publish experiments triggered from the CI to a specific "ci" project
    if experiment["name"] == "ci":
        project = "ci"
    else:
        project = f'{experiment["src"]}-{experiment["trg"]}'

    return (
        project,
        f'{experiment["name"]}_{group_id}',
        task_name,
        group_id,
    )


def get_wandb_publisher(
    project_name=None,
    group_name=None,
    run_name=None,
    taskcluster_secret=None,
    artifacts=[],
    tags=[],
    logs_file=None,
    publication=False,
):
    """
    Build a ``WandB`` publisher, or return ``None`` when publication is skipped.

    Publication is skipped when ``publication`` is false or when no project
    name is available. When ``taskcluster_secret`` is given, the API key is
    loaded from that Taskcluster secret (unless ``WANDB_API_KEY`` is already
    set). The project, group and run names are then auto-detected from the
    Taskcluster payloads, overriding the values passed in.

    ``artifacts`` and ``tags`` are never mutated by this function.
    If ``WANDB_AUTHOR`` is set, an ``author:<name>`` tag is added to a copy of
    ``tags``.
    """
    if not publication:
        logger.info(
            "Skip weight & biases publication as requested by operator through WANDB_PUBLICATION"
        )
        return

    # Load secret from Taskcluster and auto-configure naming
    suffix = ""
    if taskcluster_secret:
        assert os.environ.get(
            "TASKCLUSTER_PROXY_URL"
        ), "When using `--taskcluster-secret`, `TASKCLUSTER_PROXY_URL` environment variable must be set too."

        # Weights & Biases client uses an environment variable to read the token.
        # Only query the secret when no key is configured yet:
        # `os.environ.setdefault(key, get_wandb_token(...))` would evaluate the
        # (network-bound, possibly failing) lookup even when it is not needed.
        if "WANDB_API_KEY" not in os.environ:
            os.environ["WANDB_API_KEY"] = get_wandb_token(taskcluster_secret)

        project_name, group_name, run_name, task_group_id = get_wandb_names()
        suffix = suffix_from_group(task_group_id)

    # Enable publication on weight and biases when project is set
    # But prevent running when explicitly disabled by operator
    if not project_name:
        logger.info("Skip weight & biases publication as project name is not set")
        return

    # Build optional configuration with log file
    config = {}
    if logs_file:
        config["logs_file"] = logs_file

    # Automatically adds experiment owner to the tags.
    # Build a new list instead of appending in place: `tags` may be this
    # function's shared `[]` default or an argparse default list, and mutating
    # either would leak the author tag into later calls.
    if author := os.environ.get("WANDB_AUTHOR"):
        tags = [*tags, f"author:{author}"]

    return WandB(
        project=project_name,
        group=group_name,
        name=run_name,
        suffix=suffix,
        artifacts=artifacts,
        tags=tags,
        config=config,
    )


def list_existing_group_logs_metrics(
    wandb_run: wandb.sdk.wandb_run.Run,
) -> List[List[str | float]]:
    """
    Retrieve the rows of the most recent ``media/table/metrics*`` table of a
    resumed run.

    Returns an empty list when the run was not resumed, when no such table
    file exists, or when the table JSON has no ``"data"`` key.
    """
    if wandb_run.resumed is False:
        return []
    logger.info(f"Retrieving existing group logs metrics from group_logs ({wandb_run.id})")

    api = wandb.Api()
    run = api.run(f"{wandb_run.project}/{wandb_run.id}")
    # Walk the run files from last to first and keep the first metrics table found.
    last = next(
        (
            artifact
            for artifact in reversed(list(run.files()))
            if artifact.name.startswith("media/table/metrics")
        ),
        None,
    )
    if last is None:
        return []

    # `download` returns an open file handle; close it once the JSON is parsed.
    with last.download(replace=True) as table_file:
        data = json.load(table_file)
    return data.get("data", [])