def push_to_wandb()

in src/nanotron/eval/upload_to_wandb.py
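
Reads per-benchmark evaluation results (JSON) from an S3 prefix and logs them to wandb as a single step, resuming an existing run with the same name when one is found so successive checkpoints extend the same curves.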


import json
import re

import s3fs
import wandb
from wandb.util import generate_id  # short random id helper from the wandb SDK


def push_to_wandb(wandb_project, wandb_entity, model_name, results_path, train_step, consumed_tokens):
    s3 = s3fs.S3FileSystem(anon=False)
    all_metrics = {
        # alternative x-axis values, logged alongside every metric so charts can plot against either
        "consumed_tokens": consumed_tokens,
        "train_step": train_step,
    }

    for result_file in sorted(s3.ls(results_path)):
        if not result_file.endswith(".json"):
            continue

        with s3.open(result_file, "r") as f:
            results = json.load(f)["results"]

            for benchmark, metrics in results.items():
                if benchmark == "all":
                    continue

                # extract dataset and config name
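                # e.g. "leaderboard|mmlu:abstract_algebra|5" -> ("mmlu", "abstract_algebra")
                #      "lighteval|hellaswag|0"               -> ("hellaswag", None)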
                match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark)
                if match:
                    dataset, subtask = match.groups()

                    for metric_name, metric_value in metrics.items():
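                        # stderr entries accompany each point estimate; skip them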
                        if "_stderr" in metric_name:
                            continue
                        # wandb-friendly metric name
                        wandb_metric = f"{dataset}/{subtask}/{metric_name}" if subtask else f"{dataset}/{metric_name}"
                        all_metrics[wandb_metric] = metric_value

    # default to a fresh run id; replaced below if a run with this name already exists
    run_id = f"{model_name}-{generate_id()}"

    # try to find the run in wandb and resume it
    api = wandb.Api()
    runs = api.runs(f"{wandb_entity}/{wandb_project}")
    for run in runs:
        if run.name == model_name:
            run_id = run.id
            break

    wandb.init(
        project=wandb_project,
        entity=wandb_entity,
        name=model_name,
        id=run_id,
        config={
            "model_name": model_name,
        },
        resume="allow",
    )

    # log all metrics for this checkpoint
    wandb.log(all_metrics)

    wandb.finish()
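

A minimal sketch of how one might invoke this for a single checkpoint; the project, entity, bucket path, and token count below are illustrative placeholders, not values from the repository:

# hypothetical invocation; all argument values are placeholders
push_to_wandb(
    wandb_project="nanotron-evals",                  # assumed project name
    wandb_entity="my-team",                          # assumed entity
    model_name="llama-1b-baseline",                  # run name to create or resume
    results_path="s3://my-bucket/evals/step-1000/",  # assumed S3 prefix holding the result JSONs
    train_step=1000,
    consumed_tokens=2_097_152_000,                   # e.g. step * global batch size * sequence length
)

Because the function resumes the run matching model_name (resume="allow"), calling it once per checkpoint appends one new point to each metric's history rather than creating a new run each time.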