tracking/translations_parser/cli/taskcluster_group.py:

#!/usr/bin/env python3 """ Track training experiments from a Taskcluster group and publish them to Weight and Biases. Example: track_tc_group --group-id=<group_id> """ import argparse import logging import tempfile from collections import defaultdict from pathlib import Path import wandb import taskcluster from taskcluster.download import downloadArtifactToBuf from translations_parser.data import Metric from translations_parser.parser import TrainingParser, logger from translations_parser.publishers import WandB from translations_parser.utils import ( MULTIPLE_TRAIN_SUFFIX, build_task_name, parse_task_label, publish_group_logs_from_tasks, suffix_from_group, ) KIND_TAG_TARGET = ("train", "finetune") queue = taskcluster.Queue({"rootUrl": "https://firefox-ci-tc.services.mozilla.com"}) def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Track training experiments from a Taskcluster group" ) parser.add_argument( "group_id", help="ID of the Taskcluster training task group.", ) parser.add_argument( "--no-recursive-lookup", help="Disable group traversal from provided group_id tasks dependencies.", action="store_true", ) parser.add_argument( "--override-runs", help="Override runs on Weight & Biases.", action="store_true", ) parser.add_argument( "--verbose", "-v", help="Print debug messages.", action="store_const", dest="loglevel", const=logging.DEBUG, ) return parser.parse_args() def get_logs(task: dict) -> list[str]: """Retrieve training logs from Taskcluster""" task_id = task["status"]["taskId"] logger.info(f"Downloading logs for task {task_id}") try: log, _ = downloadArtifactToBuf( taskId=task_id, name="public/build/train.log", queueService=queue, ) except Exception as e: logger.error(f"Could not retrieve logs: {e}") return [] return log.tobytes().decode().split("\n") def publish_task( *, project: str, group: str, name: str, suffix: str, task: dict, metrics: list[Metric] ) -> None: logs = get_logs(task) if not logs: logger.warning(f"Skipping publication of training task {name}") return parser = TrainingParser( logs, publishers=[ WandB( project=project, group=group, name=name, suffix=suffix, tags=["taskcluster-offline"], ) ], metrics=metrics, ) parser.run() def get_metrics_from_task(task: dict) -> list[Metric]: task_id = task["status"]["taskId"] logger.info(f"Retrieving artifacts from evaluation task {task_id}") metrics = [] for artifact in queue.listLatestArtifacts(task_id)["artifacts"]: if not artifact["name"].endswith(".metrics"): continue log, _ = downloadArtifactToBuf( taskId=task_id, name=artifact["name"], queueService=queue, ) tag = task["task"]["tags"]["label"] # Remove eventual slashes (e.g. 
<task_tag>-1/2) that cannot be written to the filesystem tag = MULTIPLE_TRAIN_SUFFIX.sub("", tag) with tempfile.TemporaryDirectory() as temp_dir: file = Path(temp_dir) / f"{tag}.txt" with file.open("wb") as log_file: log_file.write(log.tobytes()) log_file.flush() metrics.append(Metric.from_file(Path(log_file.name))) return metrics def filter_task(task: dict) -> tuple[str, dict] | tuple[None, None]: if task["status"]["state"] == "completed" and "vocab" not in task["task"]["tags"]["kind"]: try: prefix, task["name"] = build_task_name(task["task"]) except ValueError: # Task label may be unrelated to training or validation label = task["task"].get("tags", {}).get("label", "unknown") logger.debug(f"Skipping task with label {label}") else: return prefix, task return None, None def list_training_tasks(group_id: str, grouped_tasks: dict[str, list[dict]]) -> list[list[dict]]: training_tasks = sum( [tasks for key, tasks in grouped_tasks.items() if key in KIND_TAG_TARGET], start=[] ) if not training_tasks: logger.warning(f"No completed training task found for group {group_id}") else: logger.info(f"Found {len(training_tasks)} completed training tasks") return training_tasks def list_metrics_tasks(group_id: str, grouped_tasks: dict[str, list[dict]]) -> dict[str, dict]: metrics_tasks = {task["status"]["taskId"]: task for task in grouped_tasks["evaluate"]} if not metrics_tasks: logger.warning(f"No completed metrics task found for group {group_id}") else: logger.info(f"Found {len(metrics_tasks)} completed metrics tasks") return metrics_tasks def list_completed_tasks(group_id: str) -> dict[str, list[dict]]: logger.info(f"Listing completed tasks from group {group_id}") response = queue.listTaskGroup(group_id) tasks = response["tasks"] continuation_token = response.get("continuationToken") while continuation_token: # Results may be returned in multiple pages # https://docs.taskcluster.net/docs/reference/platform/queue/api#listTaskGroup response = queue.listTaskGroup(group_id, {"continuationToken": continuation_token}) tasks.extend(response["tasks"]) continuation_token = response.get("continuationToken") # Map tasks by categories grouped_tasks = defaultdict(list) for task in tasks: # Exclude non completed or vocab tasks prefix, filtered_task = filter_task(task) if filtered_task: grouped_tasks[prefix].append(filtered_task) return grouped_tasks def publish_task_group(group_id: str, override: bool = False) -> None: logger.info(f"Retrieving task group {group_id}") # Ensure task group is readable queue.getTaskGroup(group_id) # Read project and experiment name from task group configuration task_group = queue.task(group_id) config = task_group.get("extra", {}).get("action", {}).get("context", {}).get("input") # If the task group does not have a training configuration, we can skip its publication if config is None: logger.warning( f"Task group {group_id} cannot be published to WandB: " "configuration missing @ extra/action/context/input" ) return experiment = config["experiment"] project_name = f'{experiment["src"]}-{experiment["trg"]}' group_name = f'{experiment["name"]}_{group_id}' suffix = suffix_from_group(group_id) grouped_tasks = list_completed_tasks(group_id) training_tasks = list_training_tasks(group_id, grouped_tasks) metrics_tasks = list_metrics_tasks(group_id, grouped_tasks) if not training_tasks: logger.warning(f"Skipping task group {group_id} as it is empty") return logger.info(f"Processing group {group_name}") if override: existing_runs = list(wandb.Api().runs(project_name, filters={"group": 
group_name})) for run in existing_runs: logger.warning(f"Deleting existing run {run.display_name}.") run.delete() # Publish training tasks as runs for training_task in training_tasks: # Associate metrics to each runs (evaluate tasks that depends on the training task) dependent_tasks = [] for eval_id, eval_task in metrics_tasks.items(): eval_label = eval_task["task"]["tags"].get("label", "") try: model_name = parse_task_label(eval_label).model except ValueError: continue # Evaluation tasks must be a dependency of the run and match its name if ( training_task["status"]["taskId"] in eval_task["task"]["dependencies"] and model_name == training_task["name"] ): dependent_tasks.append(eval_id) metrics = sum( [ get_metrics_from_task(metrics_tasks.pop(dependent_task_id)) for dependent_task_id in dependent_tasks ], start=[], ) publish_task( project=project_name, group=group_name, suffix=suffix, name=training_task["name"], task=training_task, metrics=metrics, ) # Group and publish remaining metrics tasks via the logs publication publish_group_logs_from_tasks( project=project_name, group=group_name, suffix=suffix, metrics_tasks=metrics_tasks, config=config, ) def list_dependent_group_ids(task_id: str, known: set[str]): task = queue.task(task_id) # Browse task dependencies for dependent_task_id in task["dependencies"]: dependent_status = queue.status(dependent_task_id) group_id = dependent_status["status"]["taskGroupId"] if group_id in known: continue yield group_id known.add(group_id) # Shared instance of `known` to propagate discovered groups in real time across all recursion branches yield from list_dependent_group_ids(dependent_task_id, known) def main() -> None: args = get_args() if args.loglevel: logger.setLevel(args.loglevel) groups_ids = {args.group_id} if not args.no_recursive_lookup: logger.info(f"Retrieving related groups from {args.group_id} training tasks dependencies") completed_tasks = list_completed_tasks(args.group_id) training_tasks = list_training_tasks(args.group_id, completed_tasks) for training_task in training_tasks: dependent_ids = list_dependent_group_ids( training_task["status"]["taskId"], {*groups_ids} ) groups_ids.update(dependent_ids) logger.info( f"Found {len(groups_ids) - 1} additional groups to browse for WandB publication" ) else: logger.info( "--no-recursive-lookup option is set, only the provided group will be browsed for WandB publication" ) for group_id in groups_ids: publish_task_group(group_id, override=args.override_runs)
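
Usage sketch (an assumption, not part of the module: it supposes the package is installed so this file is importable under the `translations_parser.cli.taskcluster_group` path shown above, and it uses a placeholder group ID):

    from translations_parser.cli.taskcluster_group import publish_task_group

    # Placeholder Taskcluster group ID; substitute a real training group.
    group_id = "A1b2C3d4E5f6G7h8I9j0Kl"

    # Publish this single group's training and evaluation tasks to W&B,
    # keeping any pre-existing runs (pass override=True to delete them first).
    publish_task_group(group_id, override=False)

The `main()` entry point wraps the same call with the recursive group lookup: it walks each training task's dependency tree via `list_dependent_group_ids` and publishes every task group it discovers, unless `--no-recursive-lookup` is set.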