def move_params_to_devices()

in training/flax/distil_whisper/partitioner.py [0:0]


    def move_params_to_devices(self, train_state: TrainState, train_state_axes: TrainState) -> TrainState:
        """Moves the train state onto devices, sharded according to ``train_state_axes``.

        The state is passed through a partitioned identity function (``_id_fn``) whose
        input and output partitionings are both ``train_state_axes``. The returned state
        therefore ends up laid out on the device mesh as those axes describe.

        Args:
            train_state: The train state to place on devices. ``donate_argnums=(0,)`` is
                passed to ``self.partition``, so its buffers are presumably donated and
                should not be reused by the caller after this call.
            train_state_axes: Partitioning spec for ``train_state``. It is used as the
                in/out axis resources for the state and must presumably mirror its pytree
                structure.

        Returns:
            The train state as produced by the partitioned identity function, i.e.
            placed on devices with the requested sharding.
        """
        p_id_fn = self.partition(
            _id_fn,
            in_axis_resources=(train_state_axes, None),
            out_axis_resources=(train_state_axes, None),
            donate_argnums=(0,),
        )
        # `jax.config.jax_array` only exists on JAX versions where `jax.Array` was still
        # opt-in. Newer releases removed the flag (jax.Array is always enabled), so treat a
        # missing flag as True instead of letting the attribute access raise.
        jax_array_enabled = getattr(jax.config, "jax_array", True)
        if jax_array_enabled and jax.process_count() > 1:
            # Multi-process run: convert the host-local state into global jax.Arrays
            # over `self.mesh` before feeding it to the partitioned function.
            train_state = multihost_utils.host_local_array_to_global_array(train_state, self.mesh, train_state_axes)
        # The second argument is a dummy uint32 scalar (partitioned with `None`).
        # NOTE(review): presumably `_id_fn` uses it so the call cannot be optimized away
        # as a pure identity; confirm against `_id_fn`. Its output is discarded.
        train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
        return train_state