# src/nanotron/metrics_logging.py
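# NOTE: the imports below are inferred from usage in this excerpt. The helpers
# `get_attribute_by_path` and `compute_tensor_stats` are assumed to be defined
# elsewhere in this module; hedged sketches of both are given after the method.
from typing import Dict

import torch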
def compute_global_hidden_layer_metrics(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """
    Compute global metrics across all hidden layers, excluding embeddings, final layernorm, and lm_head.
    Aggregates weights from all decoder layers to provide model-wide statistics.
    """
    metrics = {}
    max_layers = self.config.model.model_config.num_hidden_layers
    if max_layers == 0:
        return metrics

    formatted_paths = self._format_paths(self.MODEL_COMPONENTS, max_layers)

    # Group weights by component type
    component_weights = {comp_type: [] for comp_type in self.MODEL_COMPONENTS}
    all_layer_weights = []

    # Collect all weights from hidden layers
    for layer_idx in range(max_layers):
        for comp_type, subcomponents in self.MODEL_COMPONENTS.items():
            for subcomp_name in subcomponents:
                path = formatted_paths[comp_type][subcomp_name][layer_idx]
                param = get_attribute_by_path(model, path)
                if param is not None:
                    param_tensor = param.detach().float()
                    all_layer_weights.append(param_tensor)
                    component_weights[comp_type].append(param_tensor)

    if not all_layer_weights:
        return metrics

    # Compute statistics for each component type
    for comp_type, weights in component_weights.items():
        if not weights:
            continue
        # Flatten tensors for global statistics calculation
        flat_tensors = [w.reshape(-1) for w in weights]
        all_comp_weights = torch.cat(flat_tensors)
        prefix = f"global_{comp_type}"
        comp_stats = compute_tensor_stats(all_comp_weights)
        for stat_name, value in comp_stats.items():
            metrics[f"{prefix}/{stat_name}"] = value

    # Compute global stats across all hidden layers
    flat_all_tensors = [w.reshape(-1) for w in all_layer_weights]
    all_weights = torch.cat(flat_all_tensors)
    global_stats = compute_tensor_stats(all_weights)
    for stat_name, value in global_stats.items():
        metrics[f"global_global/{stat_name}"] = value

    return metrics
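
# ---------------------------------------------------------------------------
# Hedged sketch: `get_attribute_by_path` and `compute_tensor_stats` are used
# above but not shown in this excerpt. The minimal versions below are
# assumptions about their contracts, not nanotron's actual implementations;
# they exist only to make the method above readable in isolation.
# ---------------------------------------------------------------------------


def get_attribute_by_path(obj, path: str):
    """Resolve a dotted path such as "model.decoder.3.attn.o_proj.weight".

    Assumed contract: walks attributes, treating purely numeric segments as
    indices (e.g. into an nn.ModuleList), and returns None instead of raising
    when any segment is missing.
    """
    for part in path.split("."):
        if part.isdigit():
            try:
                obj = obj[int(part)]
            except (IndexError, KeyError, TypeError):
                return None
        else:
            obj = getattr(obj, part, None)
        if obj is None:
            return None
    return obj


def compute_tensor_stats(tensor: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Assumed stat set; the real helper may compute more (or different) stats."""
    return {
        "mean": tensor.mean(),
        "std": tensor.std(),
        "min": tensor.min(),
        "max": tensor.max(),
        "abs_mean": tensor.abs().mean(),
        "l2_norm": tensor.norm(p=2),
    }


# With the assumed stat names above, and MODEL_COMPONENTS keys like "attn" and
# "mlp" (illustrative, not confirmed by this excerpt), the method yields keys
# such as:
#   global_attn/mean, global_attn/std, ..., global_mlp/l2_norm,
#   global_global/mean, ...   ("global_global" aggregates every component type)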