# src/nanotron/eval/upload_to_wandb.py
import json
import re

import s3fs
import wandb
from wandb.util import generate_id


def push_to_wandb(wandb_project, wandb_entity, model_name, results_path, train_step, consumed_tokens):
    """Collect eval result JSONs from S3 and log their metrics to a wandb run."""
    s3 = s3fs.S3FileSystem(anon=False)
    all_metrics = {
        # basic x-axis replacements for all metrics, so results can be plotted
        # against training progress instead of wall-clock time
        "consumed_tokens": consumed_tokens,
        "train_step": train_step,
    }
    for result_file in sorted(s3.ls(results_path)):
        if not result_file.endswith(".json"):
            continue
        with s3.open(result_file, "r") as f:
            results = json.loads(f.read())["results"]
        for benchmark, metrics in results.items():
            # skip the aggregate entry; only per-benchmark metrics are logged
            if benchmark == "all":
                continue
            # extract dataset and config name, e.g. "suite|dataset:subtask|..." -> ("dataset", "subtask")
            match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark)
            if match:
                dataset, subtask = match.groups()
                for metric_name, metric_value in metrics.items():
                    # drop standard-error entries; only point estimates are logged
                    if "_stderr" in metric_name:
                        continue
                    # wandb-friendly metric name: "dataset/subtask/metric" or "dataset/metric"
                    wandb_metric = f"{dataset}/{subtask}/{metric_name}" if subtask else f"{dataset}/{metric_name}"
                    all_metrics[wandb_metric] = metric_value
    # default to a fresh run id; if a run with this name already exists, resume it instead
    run_id = f"{model_name}-{generate_id()}"
    api = wandb.Api()
    runs = api.runs(f"{wandb_entity}/{wandb_project}")
    for run in runs:
        if run.name == model_name:
            run_id = run.id
            break
    wandb.init(
        project=wandb_project,
        entity=wandb_entity,
        name=model_name,
        id=run_id,
        config={
            "model_name": model_name,
        },
        resume="allow",
    )
    # log all metrics for this checkpoint in a single step
    wandb.log(all_metrics)
    wandb.finish()
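

# Usage sketch (an assumption, not part of the original file): push_to_wandb
# might be invoked once an eval job has written its JSON results to S3. Every
# name and path below is a hypothetical placeholder.
if __name__ == "__main__":
    push_to_wandb(
        wandb_project="nanotron-evals",  # hypothetical project name
        wandb_entity="my-team",  # hypothetical entity
        model_name="my-1b-model",  # wandb run name to create or resume
        results_path="s3://my-bucket/evals/my-1b-model/step-1000/",  # hypothetical S3 prefix
        train_step=1000,
        consumed_tokens=2_000_000_000,
    )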