in training/flax/distil_whisper/partitioner.py
def move_params_to_devices(self, train_state: TrainState, train_state_axes: TrainState) -> TrainState:
    """Moves the optimizer parameters to devices."""
    # Partition an identity function over the train state so that calling it
    # reshards the parameters onto the mesh; donate_argnums lets XLA reuse
    # the input buffers for the sharded output.
    p_id_fn = self.partition(
        _id_fn,
        in_axis_resources=(train_state_axes, None),
        out_axis_resources=(train_state_axes, None),
        donate_argnums=(0,),
    )
    if jax.config.jax_array and jax.process_count() > 1:
        # With jax.Array enabled on multiple hosts, first stitch the
        # host-local shards into a single global array.
        train_state = multihost_utils.host_local_array_to_global_array(train_state, self.mesh, train_state_axes)
    train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
    return train_state
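
For context, `_id_fn` is defined elsewhere in this module and not shown in this excerpt. A minimal sketch of it, assuming the T5X partitioner this file is adapted from, so treat the exact body as an assumption rather than the verbatim source:

def _id_fn(x, ix):
    """Identity function for copying parameters to the devices, sharded."""
    # A pure identity on x, so the donated input buffer can safely be reused
    # for the resharded output; the PRNG split merely produces a fresh second
    # output derived from the uint32 seed passed in by the caller.
    y = jax.random.split(jax.random.PRNGKey(jnp.array(ix, dtype=jnp.uint32)))
    return x, y

A hypothetical call site during setup (names illustrative): obtain train_state_axes from the partitioner's get_mesh_axes(train_state), then call train_state = partitioner.move_params_to_devices(train_state, train_state_axes) to get a train state whose parameters are sharded across the mesh.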