in training/flax/distil_whisper/partitioner.py [0:0]
def get_mesh_axes(self, train_state: TrainState) -> TrainState:
    """Translate the logical axes of ``train_state`` into mesh axes.

    The logical-axes state returned by ``self.get_logical_axes`` is flattened
    into a single-level dict keyed by "/"-joined paths, each leaf is converted
    with ``flax_partitioning.logical_to_mesh_axes`` using
    ``self._logical_axis_rules``, and the result is restored into the same
    state structure.

    Args:
        train_state: The state whose logical axes should be mapped.

    Returns:
        A copy of TrainState with Optional[PartitionSpecs] as leaves.

    Raises:
        ValueError: If a leaf's logical axes cannot be mapped; the message
            names the offending parameter path and the original error is
            chained as the cause.
    """
    logical_state = self.get_logical_axes(train_state)

    # keep_empty_nodes=True so that empty sub-trees survive the round trip
    # through flatten/unflatten instead of being dropped.
    flat_logical = traverse_util.flatten_dict(
        logical_state.state_dict(),
        keep_empty_nodes=True,
        sep="/",
    )

    flat_mesh_axes = {}
    for param_name, axes in flat_logical.items():
        # Missing annotations (None) and empty-node placeholders are passed
        # through untouched; only real logical-axis entries are mapped.
        if axes is None or axes is traverse_util.empty_node:
            flat_mesh_axes[param_name] = axes
            continue

        try:
            mesh_axes = flax_partitioning.logical_to_mesh_axes(axes, self._logical_axis_rules)
        except ValueError as e:
            # Re-raise with the parameter path so the failing leaf is identifiable.
            raise ValueError(f"Failed to map logical axes for {param_name}") from e
        flat_mesh_axes[param_name] = mesh_axes

    nested_mesh_axes = traverse_util.unflatten_dict(flat_mesh_axes, sep="/")
    return logical_state.restore_state(nested_mesh_axes)