# text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
def shard_weights(env, weights, weight_shardings):
    """Place each weight on devices according to *weight_shardings*.

    Args:
        env: environment object providing ``sharding_by_axis`` — presumably a
            JetStream/Pytorch-XLA engine environment; confirm against caller.
        weights: mapping of parameter name -> torch tensor.
        weight_shardings: mapping of parameter name -> shard axis. Names absent
            from this mapping fall back to axis ``-1`` (replication, per
            ``env.sharding_by_axis`` — TODO confirm).

    Returns:
        dict mapping each parameter name to a torch view of the device-placed
        jax array.
    """
    # Emit the full sharding plan up front for debugging.
    for name, axis in weight_shardings.items():
        logger.debug(f"SHARDING {name} {axis}")

    result = {}
    for name, tensor in weights.items():
        shard_axis = weight_shardings.get(name, -1)
        sharding = env.sharding_by_axis(shard_axis)
        with jax.default_device(jax.devices("cpu")[0]):
            # Clone first: calling device_put on the original tensor's buffer
            # might core-dump otherwise (see original note).
            array = torch_xla2.tensor.t2j(tensor.clone())
            array = pad_to_shard(env, array, shard_axis)
            array = jax.device_put(array, sharding)
            result[name] = torchjax.to_torch(array)
    return result