in var/ramble/repos/builtin/applications/maxtext/application.py [0:0]
def _prepare_analysis(self, workspace, app_inst):
"""Reads JSON metrics_files output from MaxText and formats them in a new file
to be processed as FOMs by Ramble."""
metrics_filename = get_file_path(
canonicalize_path(
app_inst.expander.expand_var_name("metrics_file")
),
workspace,
)
# Avoid issues with non-existent metrics files
if not os.path.exists(metrics_filename):
return
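# A workflow node ID embedded in the path means there is one metrics file per
# node; glob across all of them instead of reading a single literal filename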
workflow_node_id = app_inst.expander.expand_var_name(
"workflow_node_id"
)
if workflow_node_id in metrics_filename:
metrics_filename = metrics_filename.replace(workflow_node_id, "*")
logger.debug(
f"Workflow node ID expansion detected. Searching for files with pattern {metrics_filename}"
)
metrics_path = os.path.join(
app_inst.expander.experiment_run_dir, metrics_filename
)
metrics_files = sorted(glob.glob(metrics_path))
if not metrics_files:
logger.die(
f"Unable to locate metrics file(s) at:\n {metrics_path}"
)
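# Read every matched metrics file; each non-empty line is one JSON record of per-step metrics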
imported_metrics_data = []
for file in metrics_files:
try:
with open(file) as f:
imported_metrics_data.append(f.read().strip())
except FileNotFoundError:
logger.debug(f"File not found: {file}")
except Exception as e:
logger.debug(f"An error occurred when reading file: {file}\n")
logger.debug(f"Error: {e}")
imported_metrics_data = "\n".join(imported_metrics_data)
aggregated_metrics = {}
metrics_list = []
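# Run-level values captured once from the first record that reports them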
total_tflops = None
total_weights = None
num_devices = None
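# Map MaxText metric keys to human-readable FOM labels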
expected_metrics = {
"perf/step_time_seconds": "Seconds",
"perf/per_device_tflops_per_sec": "TFLOP/s/device",
"perf/per_device_tokens": "Tokens/device",
"perf/per_device_tokens_per_sec": "Tokens/s/device",
"learning/loss": "Loss",
"learning/moe_lb_loss": "MoE Load Balancing Loss",
"learning/grad_norm": "Grad Norm",
"learning/param_norm": "Param Norm",
"learning/raw_grad_norm": "Raw Grad Norm",
"learning/current_learning_rate": "Current Learning Rate",
}
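# Group each metric's per-device values by training step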
try:
for line in imported_metrics_data.splitlines():
# Skip blank lines left by empty metrics files before JSON parsing
if not line.strip():
continue
line_dict = json.loads(line)
current_step = line_dict["step"]
if total_tflops is None:
total_tflops = line_dict["perf/per_device_tflops"]
if total_weights is None:
total_weights = line_dict["learning/total_weights"]
if current_step not in aggregated_metrics:
aggregated_metrics[current_step] = {}
for metric in expected_metrics:
# Collect one value per reporting device for this step
aggregated_metrics[current_step].setdefault(metric, []).append(
line_dict[metric]
)
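# Build one summary line per step, averaging each metric across devices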
for step, data in aggregated_metrics.items():
formatted_metrics = [f"Step: {step}"]
for metric, title in expected_metrics.items():
metric_values = data[metric]
if metric_values:
if num_devices is None:
# One value per reporting device, so the list length gives the device count
num_devices = len(metric_values)
mean = stats.StatsMean()
formatted_metrics.append(
f"Avg {title}: {mean.compute(metric_values)}"
)
else:
logger.debug(
f"No data found for step {step} metric {metric}"
)
line_out = ", ".join(formatted_metrics)
metrics_list.append(line_out)
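# Write the aggregated summary to metrics.out in the run directory for Ramble to parse as FOMs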
metrics_outfile_path = os.path.join(
app_inst.expander.experiment_run_dir, "metrics.out"
)
with open(metrics_outfile_path, "w") as metrics_out:
metrics_out.write(f"Total TFLOPS: {total_tflops}\n")
metrics_out.write(f"Total Weights: {total_weights}\n")
metrics_out.write(
f"Total Steps: {max(aggregated_metrics.keys()) + 1}\n"
)
metrics_out.write(f"Number of Devices: {num_devices}\n")
for line in metrics_list:
metrics_out.write(line + "\n")
except Exception as e:
logger.die(f"Error processing metrics data: {e}")