def write_metric()

in training/flax/run_finetuning.py [0:0]


def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step, logging_steps):
    """Send the training time, training metrics and eval metrics to a summary writer.

    Args:
        summary_writer: Object exposing ``scalar(tag, value, step)``.
        train_metrics: Raw training metrics; passed through ``get_metrics``,
            whose result is iterated as a mapping of metric name to a
            sequence of values.
        eval_metrics: Mapping of metric name to a single value; each entry is
            written at ``step`` under the ``eval/`` prefix.
        train_time: Value written under the ``train/time`` tag at ``step``.
        step: Current step; also the exclusive upper bound of the step grid
            used to place the training values.
        logging_steps: Stride of the step grid used to place the training
            values.
    """
    emit = summary_writer.scalar
    emit("train/time", train_time, step)

    collected = get_metrics(train_metrics)
    for name, series in collected.items():
        label = f"train/{name}"
        # Keep only the trailing grid points, one per recorded value, so the
        # series lines up with the most recent multiples of `logging_steps`
        # below `step`.
        step_grid = np.arange(0, step, logging_steps)
        logged_at = step_grid[-len(series) :]
        for position, reading in enumerate(series):
            emit(label, reading, logged_at[position])

    # Evaluation metrics are single values, all reported at the current step.
    for eval_name, eval_value in eval_metrics.items():
        emit(f"eval/{eval_name}", eval_value, step)