in training/flax/distil_whisper/train_state.py [0:0]
def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
    """Build an inference state from a model's variable collections.

    The ``"params"`` collection becomes ``params``. If a ``"params_axes"``
    collection is present, it is checked against ``params`` with
    ``_validate_params_axes`` and stored as ``params_axes``; otherwise
    ``params_axes`` is ``None``. Every remaining collection is passed to
    ``_split_variables_and_axes``, which splits it into mutables and their
    corresponding axes.

    Args:
        model_variables: Variable collections keyed by collection name.
            Must contain ``"params"``; ``"params_axes"`` is optional.

    Returns:
        An instance of ``cls`` with ``step`` set to ``jnp.array(0)``.
        ``flax_mutables_axes`` is ``None`` when the split produced an empty
        (falsy) axes collection.
    """
    # NOTE(review): this unpacking relies on FrozenDict.pop returning
    # ``(remaining_dict, value)``. A plain ``dict.pop`` returns only the
    # value, so confirm callers always pass a FrozenDict as annotated.
    other_variables, params = model_variables.pop("params")
    if "params_axes" in other_variables:
        other_variables, params_axes = other_variables.pop("params_axes")
        _validate_params_axes(params_axes, params)
    else:
        params_axes = None

    # Split other_variables into mutables and their corresponding axes.
    flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables)
    # Normalize an empty axes collection to None, mirroring how params_axes
    # is None when no "params_axes" collection was supplied.
    flax_mutables_axes = flax_mutables_axes or None
    # Construct via ``cls`` rather than the hard-coded class name so that a
    # subclass calling ``create`` receives an instance of its own type.
    return cls(
        step=jnp.array(0),
        params=params,
        params_axes=params_axes,
        flax_mutables=flax_mutables,
        flax_mutables_axes=flax_mutables_axes,
    )