def _jnp2np()

in bindings/python/py_src/safetensors/flax.py [0:0]


def _jnp2np(jnp_dict: Dict[str, Array]) -> Dict[str, np.array]:
    for k, v in jnp_dict.items():
        jnp_dict[k] = np.asarray(v)
    return jnp_dict