in text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py [0:0]
def pad_to_shard(env, val, axis: int):
    """Pad ``val`` along ``axis`` so it can be sharded over the mesh axis mapped to it.

    Args:
        env: Environment object providing ``sharding_by_axis(axis)`` and ``mesh``.
        val: Array-like value to pad (must expose ``.shape``).
        axis: Dimension of ``val`` to shard. ``-1`` or ``None`` means no
            sharding: the value is replicated and returned unchanged.

    Returns:
        ``val`` itself when no sharding is requested. Otherwise, the result of
        ``_pad_array_up_to(val, axis, mesh_size)``, where ``mesh_size`` is the
        size of the mesh axis named in ``sharding.spec[axis]``. Presumably this
        pads dimension ``axis`` up to a multiple of ``mesh_size``; confirm
        against ``_pad_array_up_to``.
    """
    # No sharding requested: everything is replicated, so no padding is needed.
    if axis is None or axis == -1:
        return val
    sharding = env.sharding_by_axis(axis)
    # The partition spec entry for this dimension names the mesh axis it is split over.
    axis_name = sharding.spec[axis]
    size_to_pad = env.mesh.shape[axis_name]
    padded_val = _pad_array_up_to(val, axis, size_to_pad)
    # Note that, even if the value is padded, the model will only use the part that is needed.
    if val.shape != padded_val.shape:
        # Lazy %-formatting: the message is only rendered when DEBUG is enabled.
        logger.debug(
            "Sharding resulting in padding weights from %s to %s",
            val.shape,
            padded_val.shape,
        )
    return padded_val