def create()

in training/flax/distil_whisper/train_state.py [0:0]


    def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
        """Build an ``InferenceState`` at step 0 from a model variable collection.

        The ``"params"`` collection is taken out of ``model_variables``. If a
        ``"params_axes"`` collection is also present, it is taken out as well and
        checked against the params via ``_validate_params_axes``. Whatever
        collections remain are handed to ``_split_variables_and_axes``, which
        separates them into mutable variables and their axes metadata.

        Args:
          model_variables: Variable collection produced by the model. Its
            ``pop`` method is expected to return a ``(remaining, value)`` pair,
            as Flax's ``FrozenDict.pop`` does.
            NOTE(review): a plain ``dict`` would not satisfy this unpacking —
            confirm callers always pass a frozen dict.

        Returns:
          An ``InferenceState`` with ``step`` set to ``jnp.array(0)``.
          ``params_axes`` is ``None`` when no ``"params_axes"`` collection was
          present, and ``flax_mutables_axes`` is ``None`` when the split
          produced an empty (falsy) axes value.
        """
        remaining, model_params = model_variables.pop("params")

        # Default to "no axes metadata"; overwritten only when the collection exists.
        axes_for_params = None
        if "params_axes" in remaining:
            remaining, axes_for_params = remaining.pop("params_axes")
            _validate_params_axes(axes_for_params, model_params)

        # The leftover collections are mutables, possibly paired with axes entries.
        mutables, mutables_axes = _split_variables_and_axes(remaining)
        if not mutables_axes:
            # Normalize an empty axes value to None.
            mutables_axes = None

        return InferenceState(
            step=jnp.array(0),
            params=model_params,
            params_axes=axes_for_params,
            flax_mutables=mutables,
            flax_mutables_axes=mutables_axes,
        )