tracking/translations_parser/publishers.py (311 lines of code) (raw):

import csv
import logging
import sys
from abc import ABC
from collections import defaultdict
from pathlib import Path
from typing import Sequence

import wandb
import yaml

from translations_parser.data import Metric, TrainingEpoch, ValidationEpoch
from translations_parser.utils import parse_task_label, parse_gcp_metric, patch_model_name

logger = logging.getLogger(__name__)

# Names of the metric value fields: every annotated attribute of `Metric`
# except the ones identifying where the metric comes from.
METRIC_KEYS = sorted(set(Metric.__annotations__.keys()) - {"importer", "dataset", "augmentation"})


class Publisher(ABC):
    """
    Abstract class used to publish parsed data.

    Either the `handle_*` methods can be overridden for real time
    publication (introduced later on) or the `publish` method
    with all results (including parser run date, configuration…).

    Every hook is a no-op by default, so subclasses only implement
    the ones they need.
    """

    def open(self, parser) -> None:
        ...

    def handle_training(self, training: TrainingEpoch) -> None:
        ...

    def handle_validation(self, validation: ValidationEpoch) -> None:
        ...

    def handle_metrics(self, metrics: Sequence[Metric]) -> None:
        ...

    def publish(self) -> None:
        ...

    def close(self) -> None:
        ...


class CSVExport(Publisher):
    """
    Publisher writing the parsed training and validation epochs to
    `training.csv` and `validation.csv` inside an output directory.
    """

    def __init__(self, output_dir: Path) -> None:
        """
        :param output_dir: Existing directory that will receive the CSV files.
        :raises ValueError: If `output_dir` is not a directory.
        """
        # Imported locally and only used for the attribute annotation below
        from translations_parser.parser import TrainingParser

        if not output_dir.is_dir():
            raise ValueError("Output must be a valid directory for the CSV export")
        self.output_dir = output_dir
        self.parser: TrainingParser | None = None

    def open(self, parser=None) -> None:
        """Keep a reference to the parser whose output is exported by `publish`."""
        self.parser = parser

    def write_data(
        self, output: Path, entries: Sequence[TrainingEpoch | ValidationEpoch], dataclass: type
    ) -> None:
        """
        Write `entries` to the CSV file `output`, one row per entry.

        The header is built from the annotations of `dataclass`.
        Nothing is written when `entries` is empty.
        """
        if not entries:
            logger.warning(f"No {dataclass.__name__} entry, skipping.")
            # Actually skip: a header-only file would make later runs consider
            # the export as already done (see the existence checks in `publish`).
            return
        # The csv module requires files to be opened with newline="" so that
        # it controls line endings itself.
        with open(output, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=dataclass.__annotations__)
            writer.writeheader()
            for entry in entries:
                writer.writerow(vars(entry))

    def publish(self) -> None:
        """
        Export the parser output to CSV files.

        A file that already exists in the output directory is never overwritten.
        """
        assert self.parser is not None, "Parser must be set to run CSV publication."
        training_log = self.parser.output

        training_output = self.output_dir / "training.csv"
        if training_output.exists():
            logger.warning(f"Training output file {training_output} exists, skipping.")
        else:
            self.write_data(training_output, training_log.training, TrainingEpoch)

        validation_output = self.output_dir / "validation.csv"
        if validation_output.exists():
            logger.warning(f"Validation output file {validation_output} exists, skipping.")
        else:
            self.write_data(validation_output, training_log.validation, ValidationEpoch)


class WandB(Publisher):
    """
    Publisher sending parsed data to Weights & Biases (W&B).

    Once the run has been initialized by `open`, training and validation epochs
    are logged by step, and metrics are published as bar charts. When the W&B
    client could not be initialized, the `handle_*` and `close` hooks do nothing.
    """

    def __init__(
        self,
        *,
        project: str,
        group: str,
        name: str,
        suffix: str = "",
        # Optional path to a directory containing training artifacts
        artifacts: Path | None = None,
        artifacts_name: str = "logs",
        **extra_kwargs,
    ):
        """
        :param project: W&B project name.
        :param group: W&B group name.
        :param name: Base name of the run.
        :param suffix: Suffix appended to `name` to build the run identifier.
        :param artifacts: Optional directory published as an artifact on close.
        :param artifacts_name: Name and type given to that artifact.
        :param extra_kwargs: Extra arguments forwarded to `wandb.init`. A `config`
            entry is merged into the parser configuration instead.
        """
        # Imported locally and only used for the attribute annotation below
        from translations_parser.parser import TrainingParser

        # Set logging of wandb module to ERROR, so we output training logs instead
        self.wandb_logger = logging.getLogger("wandb")
        self.wandb_logger.setLevel(logging.ERROR)

        self.project = project
        self.group = group
        self.suffix = suffix
        # Build a unique run identifier based on the passed suffix
        # This ID is also used as display name on W&B, as the interface expects
        # unique display names among runs
        self.run = f"{name}{suffix}"

        self.artifacts = artifacts
        self.artifacts_name = artifacts_name
        self.extra_kwargs = extra_kwargs
        self.parser: TrainingParser | None = None
        self.wandb: wandb.sdk.wandb_run.Run | wandb.sdk.lib.disabled.RunDisabled | None = None

    def close(self) -> None:
        """
        Publish the artifacts directory (if any), write the parsed logs to the
        standard output and finish the W&B run.
        """
        if self.wandb is None:
            return

        # Publish artifacts
        if self.artifacts:
            artifact = wandb.Artifact(name=self.artifacts_name, type=self.artifacts_name)
            artifact.add_dir(local_path=str(self.artifacts.resolve()))
            self.wandb.log_artifact(artifact)

        if self.parser is not None:
            # Store Marian logs as the main log artifact, instead of W&B client runtime.
            # This will be overwritten in case an unhandled exception occurs.
            for line in self.parser.parsed_logs:
                sys.stdout.write(f"{line}\n")

        self.wandb.finish()

    def open(self, parser=None) -> None:
        """
        Initialize the W&B run and publish the dataset sizes found in the configuration.

        The run configuration is the parser `config` attribute (when available)
        updated with the `config` entry of the extra keyword arguments.
        An initialization failure is logged and leaves the publisher disabled.
        """
        self.parser = parser
        config = getattr(parser, "config", {}).copy()
        # `config` is removed from the extra arguments so that it is not passed
        # twice to `wandb.init` below.
        config.update(self.extra_kwargs.pop("config", {}))
        # Publish datasets stats directly in the dashboard
        datasets = config.pop("datasets", None)

        try:
            self.wandb = wandb.init(
                project=self.project,
                group=self.group,
                name=self.run,
                id=self.run,
                config=config,
                # Since we use unique run names based on group ID (e.g. finetune-student_MjcJG),
                # we can use "allow" mode for resuming a stopped Taskcluster run in case of
                # preemption. It will continue logging to the same run if it exists.
                # Offline publication should handle run deletion separately
                # (use --override-runs).
                resume="allow",
                **self.extra_kwargs,
            )
            if self.wandb.resumed:
                logger.info(f"W&B run is being resumed from existing run '{self.run}'.")
        except Exception as e:
            logger.error(f"WandB client could not be initialized: {e}. No data will be published.")

        # Without a run there is nothing to log to (initialization failed above)
        if self.wandb is None or datasets is None:
            return

        # Log dataset sizes as a custom bar chart
        self.wandb.log(
            {
                "Datasets": wandb.plot.bar(
                    wandb.Table(
                        columns=["Name", "Count"],
                        data=[[key, value] for key, value in datasets.items()],
                    ),
                    "Name",
                    "Count",
                    title="Datasets",
                )
            }
        )

    def generic_log(self, data: TrainingEpoch | ValidationEpoch) -> None:
        """
        Log every non-null attribute of an epoch to W&B, using its `up`
        attribute as the step.
        """
        if self.wandb is None:
            return
        # Work on a copy: `vars()` returns the live `__dict__` of the object, so
        # popping from it directly would delete the `up` attribute of the epoch
        # that is shared with the parser and the other publishers.
        epoch = dict(vars(data))
        step = epoch.pop("up")
        for key, val in epoch.items():
            if val is None:
                # Do not publish null values (e.g. perplexity in Marian 1.10)
                continue
            self.wandb.log(step=step, data={key: val})

    def handle_training(self, training: TrainingEpoch) -> None:
        """Log a training epoch to W&B."""
        self.generic_log(training)

    def handle_validation(self, validation: ValidationEpoch) -> None:
        """Log a validation epoch to W&B."""
        self.generic_log(validation)

    def handle_metrics(self, metrics: Sequence[Metric]) -> None:
        """
        Publish one bar chart per metric, titled after its importer,
        augmentation (if any) and dataset (if any).
        """
        if self.wandb is None:
            return
        for metric in metrics:
            title = metric.importer
            if metric.augmentation:
                title = f"{title}_{metric.augmentation}"
            if metric.dataset:
                title = f"{title}_{metric.dataset}"
            # Publish a bar chart (a table with values will also be available from W&B)
            self.wandb.log(
                {
                    title: wandb.plot.bar(
                        wandb.Table(
                            columns=["Metric", "Value"],
                            data=[
                                [key, getattr(metric, key)]
                                for key in METRIC_KEYS
                                if getattr(metric, key) is not None
                            ],
                        ),
                        "Metric",
                        "Value",
                        title=title,
                    )
                }
            )

    @classmethod
    def publish_group_logs(
        cls,
        *,
        logs_parent_folder: list[str],
        project: str,
        group: str,
        suffix: str,
        existing_runs: list[str] | None = None,
        snakemake: bool = False,
    ) -> None:
        """
        Publish files within `logs_dir` to W&B artifacts for a specific group.
        A fake W&B run named `group_logs` is created to publish those artifacts
        along with all evaluation files (quantized + experiments).
        If existing run is set, runs found not specified in this list will also
        be published to W&B.

        :param logs_parent_folder: Path components; all but the last one form the
            root under which the `logs`, `models` and `experiments` folders are read.
        :param project: W&B project name, also used as a folder name.
        :param group: W&B group name, also used as a folder name.
        :param suffix: Suffix appended to the name of every run created here.
        :param existing_runs: Names of the runs that are already published.
        :param snakemake: When True, `.metrics` files from the models folder are ignored.
        """
        from translations_parser.parser import TrainingParser

        try:
            if (
                len(
                    wandb.Api().runs(
                        path=project, filters={"display_name": "group_logs", "group": group}
                    )
                )
                > 0
            ):
                logger.warning("Skipping group_logs fake run publication as it already exists")
                return
        except ValueError as e:
            # Project may not exist yet as group_logs is published before the first training task
            if "could not find project" not in str(e).lower():
                logger.warning(f"Detection of a previous group_logs run failed: {e}")

        root = logs_parent_folder[:-1]
        logs_dir = Path("/".join([*root, "logs", project, group]))
        models_dir = Path("/".join([*root, "models", project, group]))
        # Old experiments use `speed` directory for quantized metrics
        quantized_metrics = sorted((models_dir / "evaluation" / "speed").glob("*.metrics"))
        logs_metrics = sorted((logs_dir / "eval").glob("eval*.log"))
        direct_metrics = sorted((logs_dir / "metrics").glob("*.metrics"))

        taskcluster_metrics = []
        # Do not retrieve metrics from models directory for legacy Snakemake experiments
        if snakemake is False:
            taskcluster_metrics = sorted(models_dir.glob("**/*.metrics"))

        if quantized_metrics:
            logger.info(f"Found {len(quantized_metrics)} quantized metrics from speed folder")
        if logs_metrics:
            logger.info(f"Found {len(logs_metrics)} metrics from task logs")
        if direct_metrics:
            logger.info(f"Found {len(direct_metrics)} Snakemake metrics from .metrics artifacts")
        if taskcluster_metrics:
            logger.info(
                f"Found {len(taskcluster_metrics)} Taskcluster metrics from .metrics artifacts"
            )

        # Store metrics by run name
        metrics = defaultdict(list)

        # Add metrics from the speed folder
        for file in quantized_metrics:
            try:
                # Raises ValueError when the file name contains no underscore
                importer, dataset = file.stem.split("_", 1)
                metrics["quantized"].append(
                    Metric.from_file(file, importer=importer, dataset=dataset)
                )
            except ValueError as e:
                logger.error(f"Could not parse metrics from {file.resolve()}: {e}")

        # Add metrics from tasks logs
        for file in logs_metrics:
            try:
                model_name, importer, dataset, aug = parse_task_label(file.stem)
                with file.open("r") as f:
                    lines = f.readlines()
                metrics[model_name].append(
                    Metric.from_tc_context(
                        importer=importer, dataset=dataset, lines=lines, augmentation=aug
                    )
                )
            except ValueError as e:
                logger.error(f"Could not parse metrics from {file.resolve()}: {e}")

        # Add metrics from old SnakeMake .metrics files
        for file in direct_metrics:
            try:
                # Parsed inside the `try`, like the task logs above, so that a single
                # unexpected file name does not abort the whole publication
                model_name, importer, dataset, aug = parse_task_label(file.stem)
                metrics[model_name].append(
                    Metric.from_file(file, importer=importer, dataset=dataset, augmentation=aug)
                )
            except ValueError as e:
                logger.error(f"Could not parse metrics from {file.resolve()}: {e}")

        # Add metrics from new Taskcluster .metrics files
        for file in taskcluster_metrics:
            model_name = patch_model_name(file.parent.name)
            try:
                metric_attrs = parse_gcp_metric(file.stem)
                metrics[model_name].append(
                    Metric.from_file(
                        file,
                        importer=metric_attrs.importer,
                        dataset=metric_attrs.dataset,
                        augmentation=metric_attrs.augmentation,
                    )
                )
            except ValueError as e:
                logger.error(f"Could not parse metrics from {file.resolve()}: {e}")

        # Publish missing runs (runs without training data)
        missing_run_metrics = {}
        if existing_runs is not None:
            missing_run_metrics = {
                run_name: run_metrics
                for run_name, run_metrics in metrics.items()
                if run_name not in existing_runs
            }

        for model_name, model_metrics in missing_run_metrics.items():
            logger.info(f"Creating missing run {model_name} with associated metrics")
            publisher = cls(
                project=project,
                group=group,
                name=model_name,
                suffix=suffix,
            )
            # An empty parser is enough here, as only metrics are published
            publisher.open(TrainingParser(logs_iter=iter([]), publishers=[]))
            publisher.handle_metrics(model_metrics)
            publisher.close()

        # Publication of the `group_logs` fake run
        config = {}
        config_path = Path("/".join([*root, "experiments", project, group, "config.yml"]))
        if not config_path.is_file():
            logger.warning(f"No configuration file at {config_path}, skipping.")
        else:
            # Publish the YAML configuration as configuration on the group run
            with config_path.open("r") as f:
                data = f.read()
            try:
                config.update(yaml.safe_load(data))
            except Exception as e:
                logger.error(f"Config could not be read at {config_path}: {e}")

        publisher = cls(
            project=project,
            group=group,
            name="group_logs",
            suffix=suffix,
        )
        publisher.wandb = wandb.init(
            project=project,
            group=group,
            name=publisher.run,
            id=publisher.run,
            config=config,
        )

        if metrics:
            # Publish all evaluation metrics to a table
            # NOTE(review): the "Augmenation" column name is misspelled. It is kept
            # unchanged because already published tables use this exact name.
            table = wandb.Table(
                columns=["Group", "Model", "Importer", "Dataset", "Augmenation", *METRIC_KEYS],
                data=[
                    [group, run_name, metric.importer, metric.dataset, metric.augmentation]
                    + [getattr(metric, attr) for attr in METRIC_KEYS]
                    for run_name, run_metrics in metrics.items()
                    for metric in run_metrics
                ],
            )
            publisher.wandb.log({"metrics": table})

        if logs_dir.is_dir():
            # Publish logs directory content as artifacts
            artifact = wandb.Artifact(name=group, type="logs")
            artifact.add_dir(local_path=str(logs_dir.resolve()))
            publisher.wandb.log_artifact(artifact)

        publisher.wandb.finish()