in vision/m4/evaluation/scripts/sync_evaluations_on_wandb.py
def main():
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(process)d - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
args = get_args()
logger.info(f"args: {args}")
api = wandb.Api(timeout=WANDB_TIMEOUT)
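# A run is kept only if it is not tagged debug/failed/killed and carries no "job_id" tag.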
def filter_out_tags(tags):
if "debug" in tags or "failed" in tags or "killed" in tags:
return False
for t in tags:
if "job_id" in t:
return False
return True
def fetch_training_run(training_run_name):
"""
Fetch training run. There can only be one corresponding training run.
If not, double check the tags (killed, failed, etc.)
"""
matching_runs = []
runs = api.runs(f"{args.wandb_entity}/{args.wandb_training_project}")
for run in runs:
if run.name == training_run_name:
matching_runs.append(run)
matching_runs = [r for r in matching_runs if filter_out_tags(r.tags)]
assert len(matching_runs) == 1, f"Expected exactly 1 matching training run, found {len(matching_runs)}: {matching_runs}"
return matching_runs[0]
def fetch_evaluation_run(evaluation_run_name):
"""
Fetch evaluation run. There can only be one corresponding evaluation run at most.
If not, double check the tags (killed, failed, etc.)
"""
matching_runs = []
runs = api.runs(f"{args.wandb_entity}/{args.wandb_eval_project}")
for run in runs:
if run.name == evaluation_run_name:
matching_runs.append(run)
matching_runs = [r for r in matching_runs if filter_out_tags(r.tags)]
assert len(matching_runs) <= 1, f"There is more than 1 matching evaluation run: {matching_runs}"
if len(matching_runs) == 0:
return None
else:
return matching_runs[0]
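# Resolve runs by name: the training run must exist, the evaluation run may not have been created yet.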
training_run = fetch_training_run(args.run_name_to_log)
logger.info("Successfully fetched the training run.")
evaluation_run = fetch_evaluation_run(args.run_name_to_log)
logger.info("Successfully fetched the (potentially `None`) evaluation run.")
def get_logged_eval_values(evaluation_run):
"""
If `evaluation_run` already exists, get the already logged values into a dictionary.
"""
logged_evaluation_values = {}
if evaluation_run is not None:
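# scan_history() yields every previously logged row; index them by optimizer step for later lookups.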
for row in evaluation_run.scan_history():
opt_step = row[OPT_STEP_LOG]
logged_evaluation_values[opt_step] = row
return logged_evaluation_values
already_logged_eval_values = get_logged_eval_values(evaluation_run)
logger.info(f"Already logged evaluation values: {already_logged_eval_values}")
def get_evaluations_values_from_json():
"""
Load all values from the json file
"""
evaluation_values = defaultdict(dict)
for evaluation_jsonl_file in args.evaluation_jsonl_files:
with open(evaluation_jsonl_file, "r") as f:
for line in f:
evaluation = json.loads(line)
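# The checkpoint path encodes the optimizer step as ".../opt_step-<N>/...".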
opt_step = int(evaluation["model_name_or_path"].split("/opt_step-")[1].split("/")[0])
task = evaluation["task"]
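# "score" is a stringified dict of metric -> value; eval() assumes it was produced by our own evaluation pipeline.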
for metric, value in eval(evaluation["score"]).items():
metric_name = f"{task}-{metric}"
if "_distribution" in metric_name:
assert isinstance(
value, list
), f"Expected a list of values for distribution metric {metric_name}, got {type(value)} | {value}"
evaluation_values[opt_step][metric_name] = wandb.Histogram(value)
elif isinstance(value, (int, float)):
evaluation_values[opt_step][metric_name] = value
else:
raise ValueError(
f"Don't know how to handle metric {metric_name} of type {type(value)} | {value}"
)
return evaluation_values
evaluation_metrics = get_evaluations_values_from_json()
logger.info(f"Evaluation values: {evaluation_metrics}")
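# Drop "gradients/" and "parameters/" columns and wandb-internal keys (those starting with "_") before re-logging.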
def filter_out_columns(row):
return {
k: v
for k, v in row.items()
if ("gradients/" not in k and "parameters/" not in k and not k.startswith("_"))
}
def convert_training_run_to_dict(training_run):
"""
Get all the logged values from the training into a dictionary.
"""
training_history = training_run.scan_history()
d = defaultdict(dict)
for row in training_history:
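# Only keep history rows that were logged against an optimizer step.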
if "num_opt_steps" not in row:
continue
row = filter_out_columns(row)
opt_step = row[OPT_STEP_LOG]
assert opt_step not in d, (
f"The current code does not support having multiple entries for a single `opt_step` ({opt_step})."
" Please double check what's happening, and if necessary, support this case (for instance by only"
" considering the entry with the last timestamp.)"
)
d[opt_step] = row
return d
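# Flatten the training history into a {opt_step: metrics} mapping.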
training_dict = convert_training_run_to_dict(training_run)
# Add values from json file to the `training_dict`
for opt_step, eval_metrics_for_opt_step in evaluation_metrics.items():
if opt_step in training_dict:
for k, v in eval_metrics_for_opt_step.items():
assert k not in training_dict[opt_step], f"Evaluation metric {k} already exists in the training metrics for opt_step {opt_step}"
training_dict[opt_step][k] = v
else:
# This case only happens when we saved a checkpoint without logging metrics on wandb.
# If `train_saving_opt_steps` is a multiple of `wandb_log_freq`, this happens when we enter the
# manual exit conditions and then evaluate this checkpoint.
training_dict[opt_step] = eval_metrics_for_opt_step
training_dict[opt_step][OPT_STEP_LOG] = opt_step
# Go through the `training_dict` and check for consistency with values already logged on the evaluation run
if evaluation_run is not None:
for opt_step, training_metrics_for_opt_step in training_dict.items():
if opt_step not in already_logged_eval_values:
continue
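# Compare each metric we are about to re-log with the value already stored on the evaluation run and fail loudly on any mismatch.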
for metric_name, metric_value in training_metrics_for_opt_step.items():
if metric_name in already_logged_eval_values[opt_step]:
logger.info(f"Metric {metric_name} at opt_step {opt_step} was already logged; checking that the values match.")
if isinstance(metric_value, wandb.Histogram):
if already_logged_eval_values[opt_step][metric_name]["_type"] != "histogram":
msg = (
f"Metric {metric_name} at opt_step {opt_step} is a histogram, but the previously logged"
" value is not a histogram: please double check."
)
raise ValueError(msg)
elif (
metric_value.to_json()["values"]
!= already_logged_eval_values[opt_step][metric_name]["values"]
):
msg = (
f"Histogram values already logged for {metric_name} at opt_step {opt_step} differ from the new"
f" ones.\nBefore: {already_logged_eval_values[opt_step][metric_name]['values']}\nAfter:"
f" {metric_value.to_json()['values']}"
)
raise ValueError(msg)
elif (
already_logged_eval_values[opt_step][metric_name] != metric_value
and metric_value is not None
and not np.isnan(metric_value)
):
raise ValueError(
f"Metric {metric_name} at opt_step {opt_step} was already logged with a different value:"
f" {already_logged_eval_values[opt_step][metric_name]} (logged) vs {metric_value} (new)."
" Please double check."
)
def get_wandb_logger(evaluation_run):
"""
Init the wandb logger.
"""
if evaluation_run is not None:
logger.info("Resuming existing wandb evaluation run")
wandb_logger = wandb.init(
resume=None,
project=args.wandb_eval_project,
entity=args.wandb_entity,
name=args.run_name_to_log,
allow_val_change=True,
id=evaluation_run.id,
)
else:
wandb_id = wandb.util.generate_id()
logger.info(f"Creating wandb run with id {wandb_id}")
wandb_logger = wandb.init(
resume=None,
project=args.wandb_eval_project,
entity=args.wandb_entity,
name=args.run_name_to_log,
allow_val_change=True,
id=wandb_id,
)
return wandb_logger
wandb_logger = get_wandb_logger(evaluation_run)
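# Re-log every row (training metrics merged with the new evaluation metrics) onto the evaluation run.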
for v in training_dict.values():
assert OPT_STEP_LOG in v, f"Row is missing {OPT_STEP_LOG}: {v}"
wandb_logger.log(v)
sleep(1)
wandb.finish(quiet=True)
logger.info("Finished wandb sync")