def pad_to_shard()

in text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py [0:0]


def pad_to_shard(env, val, axis: int):
    """Pad ``val`` along ``axis`` so it can be split over the device mesh.

    Args:
        env: Engine environment; must provide ``sharding_by_axis(axis)``
            (returning an object with a ``spec`` indexable by ``axis``) and
            ``mesh.shape`` (indexable by the mesh axis name from that spec).
        val: Array-like value with a ``.shape`` attribute.
        axis: Dimension of ``val`` to shard. ``-1`` or ``None`` means no
            sharding (the value is replicated) and ``val`` is returned as-is.

    Returns:
        ``val`` itself when no sharding is requested. Otherwise, the result of
        ``_pad_array_up_to(val, axis, size)``, where ``size`` is
        ``env.mesh.shape`` for the mesh axis named by ``sharding.spec[axis]``.
    """
    # -1 / None: no sharding is done, everything is replicated, nothing to pad.
    if axis is None or axis == -1:
        return val
    sharding = env.sharding_by_axis(axis)
    # Mesh axis name that this tensor dimension is partitioned over.
    axis_name = sharding.spec[axis]
    # NOTE(review): assumes spec[axis] is a single mesh-axis name present in
    # env.mesh.shape (not None or a tuple of names) — confirm.
    size_to_pad = env.mesh.shape[axis_name]
    # Presumably pads dimension `axis` up to a multiple of size_to_pad so it
    # divides evenly across devices — confirm against _pad_array_up_to.
    padded_val = _pad_array_up_to(val, axis, size_to_pad)
    # Note that, even if the value is padded, the model will only use the part that is needed.
    if val.shape != padded_val.shape:
        # Lazy %-style args: the message is only formatted when DEBUG is enabled.
        logger.debug(
            "Sharding resulting in padding weights from %s to %s",
            val.shape,
            padded_val.shape,
        )
    return padded_val