def compute_global_hidden_layer_metrics()

in src/nanotron/metrics_logging.py [0:0]


    def compute_global_hidden_layer_metrics(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
        """
        Compute global weight statistics over all hidden (decoder) layers.

        Embeddings, the final layernorm, and lm_head are excluded: only the
        parameters reachable via ``self.MODEL_COMPONENTS`` paths are collected,
        and statistics are reported both per component type
        (``global_<component>/<stat>``) and across every hidden-layer weight
        at once (``global_global/<stat>``).
        """
        results: Dict[str, torch.Tensor] = {}
        num_layers = self.config.model.model_config.num_hidden_layers
        if num_layers == 0:
            return results

        paths = self._format_paths(self.MODEL_COMPONENTS, num_layers)

        # Bucket each parameter tensor by component type, while also keeping
        # one flat list covering every hidden-layer weight we find.
        per_component = {name: [] for name in self.MODEL_COMPONENTS}
        collected = []
        for layer in range(num_layers):
            for name, parts in self.MODEL_COMPONENTS.items():
                for part in parts:
                    weight = get_attribute_by_path(model, paths[name][part][layer])
                    if weight is None:
                        continue
                    # Detach from the graph and upcast so stats are computed in fp32.
                    tensor = weight.detach().float()
                    collected.append(tensor)
                    per_component[name].append(tensor)

        if not collected:
            return results

        # Per-component-type statistics over the concatenated, flattened weights.
        for name, tensors in per_component.items():
            if not tensors:
                continue
            merged = torch.cat([t.reshape(-1) for t in tensors])
            prefix = f"global_{name}"
            for stat, value in compute_tensor_stats(merged).items():
                results[f"{prefix}/{stat}"] = value

        # Statistics across every hidden-layer weight at once.
        merged_all = torch.cat([t.reshape(-1) for t in collected])
        for stat, value in compute_tensor_stats(merged_all).items():
            results[f"global_global/{stat}"] = value

        return results