def _prepare_analysis()

in var/ramble/repos/builtin/applications/maxtext/application.py [0:0]


    def _prepare_analysis(self, workspace, app_inst):
        """Reads JSON metrics_files output from MaxText and formats them in a new file
        to be processed as FOMs by Ramble.

        Each non-blank line of every matching metrics file is parsed as a JSON
        object that must contain a ``step`` key, the keys listed in
        ``expected_metrics``, ``perf/per_device_tflops`` and
        ``learning/total_weights``. Records that share the same ``step`` are
        grouped, each expected metric is averaged over the group (via
        ``stats.StatsMean``), and a summary is written to ``metrics.out`` in the
        experiment run directory.

        If the resolved ``metrics_file`` path contains the expanded
        ``workflow_node_id``, that portion is replaced by ``*`` so that files
        from every workflow node are globbed and aggregated together.

        Args:
            workspace: Workspace used to resolve the ``metrics_file`` path.
            app_inst: Application instance whose expander provides
                ``metrics_file``, ``workflow_node_id`` and
                ``experiment_run_dir``.

        Returns:
            None. Returns early, without writing ``metrics.out``, when the
            resolved metrics file does not exist.

        Side effects:
            Writes ``<experiment_run_dir>/metrics.out``. Calls ``logger.die``
            when no file matches the glob pattern, when the matched files hold
            no metric records, or when parsing/writing the data fails.
        """

        metrics_filename = get_file_path(
            canonicalize_path(
                app_inst.expander.expand_var_name("metrics_file")
            ),
            workspace,
        )

        # Avoid issues with non-existent metrics files
        if not os.path.exists(metrics_filename):
            return

        workflow_node_id = app_inst.expander.expand_var_name(
            "workflow_node_id"
        )
        # Guard against an empty expansion: ``"" in s`` is always True and
        # ``s.replace("", "*")`` would insert "*" between every character,
        # producing a nonsensical glob pattern.
        if workflow_node_id and workflow_node_id in metrics_filename:
            metrics_filename = metrics_filename.replace(workflow_node_id, "*")
            logger.debug(
                f"Workflow node ID expansion detected. Searching for files with pattern {metrics_filename}"
            )

        # If metrics_filename is already absolute, os.path.join returns it
        # unchanged; otherwise it is resolved relative to the run directory.
        metrics_path = os.path.join(
            app_inst.expander.experiment_run_dir, metrics_filename
        )
        metrics_files = sorted(glob.glob(metrics_path))

        if not metrics_files:
            logger.die(
                f"Unable to locate metrics file(s) at:\n    {metrics_path}"
            )

        # Collect every non-blank line from every file. Unreadable files are
        # skipped (best effort) so one bad file does not abort the analysis;
        # blank lines are dropped because json.loads("") would raise.
        metric_lines = []
        for metrics_file in metrics_files:
            try:
                with open(metrics_file) as f:
                    file_lines = f.read().splitlines()
            except FileNotFoundError:
                logger.debug(f"File not found: {metrics_file}")
            except Exception as e:
                logger.debug(
                    f"An error occurred when reading file: {metrics_file}\n"
                )
                logger.debug(f"Error: {e}")
            else:
                metric_lines.extend(
                    line for line in file_lines if line.strip()
                )

        expected_metrics = {
            "perf/step_time_seconds": "Seconds",
            "perf/per_device_tflops_per_sec": "TFLOP/s/device",
            "perf/per_device_tokens": "Tokens/device",
            "perf/per_device_tokens_per_sec": "Tokens/s/device",
            "learning/loss": "Loss",
            "learning/moe_lb_loss": "MoE Load Balancing Loss",
            "learning/grad_norm": "Grad Norm",
            "learning/param_norm": "Param Norm",
            "learning/raw_grad_norm": "Raw Grad Norm",
            "learning/current_learning_rate": "Current Learning Rate",
        }

        # step -> {metric name -> list of values from every record for that step}
        aggregated_metrics = {}
        metrics_list = []
        total_tflops = None
        total_weights = None
        num_devices = None

        try:
            for line in metric_lines:
                line_dict = json.loads(line)
                step_metrics = aggregated_metrics.setdefault(
                    line_dict["step"], {}
                )

                # These totals are recorded from the first record that has them.
                if total_tflops is None:
                    total_tflops = line_dict["perf/per_device_tflops"]
                if total_weights is None:
                    total_weights = line_dict["learning/total_weights"]

                for metric in expected_metrics:
                    step_metrics.setdefault(metric, []).append(
                        line_dict[metric]
                    )

            if not aggregated_metrics:
                logger.die(
                    "No metrics records found in: " + ", ".join(metrics_files)
                )

            for step, data in aggregated_metrics.items():
                formatted_metrics = [f"Step: {step}"]

                for metric, title in expected_metrics.items():
                    metric_values = data.get(metric)
                    if not metric_values:
                        logger.debug(
                            f"No data found for Step {step} metric {metric}"
                        )
                        continue

                    # The number of records for the first step reported is
                    # taken as the device count.
                    if num_devices is None:
                        num_devices = len(metric_values)

                    mean = stats.StatsMean()
                    formatted_metrics.append(
                        f"Avg {title}: {mean.compute(metric_values)}"
                    )

                metrics_list.append(", ".join(formatted_metrics))

            metrics_outfile_path = os.path.join(
                app_inst.expander.experiment_run_dir, "metrics.out"
            )

            # NOTE(review): "Total Steps" assumes steps are 0-based integers.
            total_steps = max(aggregated_metrics) + 1

            with open(metrics_outfile_path, "w") as metrics_out:
                metrics_out.write(f"Total TFLOPS: {total_tflops}\n")
                metrics_out.write(f"Total Weights: {total_weights}\n")
                metrics_out.write(f"Total Steps: {total_steps}\n")
                metrics_out.write(f"Number of Devices: {num_devices}\n")
                for line in metrics_list:
                    metrics_out.write(line + "\n")

        except Exception as e:
            logger.die(f"Error reading metrics data: {e}")