def shard_weights()

in text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py [0:0]


def shard_weights(env, weights, weight_shardings):
    """Shard model weights across devices according to weight_shardings.

    Args:
        env: engine environment; must provide ``sharding_by_axis(axis)``
            returning a jax sharding spec.
        weights: mapping of parameter name -> tensor (presumably a torch
            tensor — it is ``.clone()``d and converted via torch_xla2).
        weight_shardings: mapping of parameter name -> shard axis. Names
            absent from this mapping fall back to axis -1 (presumably
            "no sharding/replicated" — confirm against env.sharding_by_axis).

    Returns:
        dict mapping each parameter name to a torch view of the
        device-placed (sharded) jax array.
    """
    for name, axis in weight_shardings.items():
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug("SHARDING %s %s", name, axis)
    sharded = {}
    for key, val in weights.items():
        # Weights without an explicit sharding entry default to axis -1.
        axis = weight_shardings.get(key, -1)
        sharding = env.sharding_by_axis(axis)
        # Do the torch->jax conversion and padding pinned to CPU so the
        # subsequent device_put performs the actual host->device transfer.
        with jax.default_device(jax.devices("cpu")[0]):
            # Note we clone to avoid a core-dump that might happen otherwise when calling device_put
            arr = torch_xla2.tensor.t2j(val.clone())
            arr = pad_to_shard(env, arr, axis)
        arr = jax.device_put(arr, sharding)
        sharded[key] = torchjax.to_torch(arr)
    return sharded