def with_sharding_constraint()

in training/flax/distil_whisper/partitioner.py [0:0]


def with_sharding_constraint(x, axis_resources):
    """Apply a JAX sharding constraint to ``x``, or return ``x`` unchanged.

    Thin wrapper around JAX's ``with_sharding_constraint`` that is a no-op
    (returns ``x`` as-is) when either:

    * the first device returned by ``jax.devices()`` is a CPU, or
    * the module-level ``global_mesh_defined()`` helper returns a falsy value
      (presumably meaning no global mesh is active, e.g. outside pjit).

    Args:
      x: The value to constrain; returned untouched on the no-op path.
      axis_resources: Sharding specification forwarded unchanged to
        ``with_sharding_constraint`` (presumably a ``PartitionSpec`` or a
        pytree of them -- confirm against callers).

    Returns:
      ``x`` itself on the no-op path, otherwise
      ``with_sharding_constraint(x, axis_resources)``.
    """
    # Checked first so CPU-only runs never need a mesh (the `or` short-circuits).
    if jax.devices()[0].platform == "cpu" or not global_mesh_defined():
        return x
    # `jax.experimental.pjit.with_sharding_constraint` is deprecated in favour
    # of the public `jax.lax.with_sharding_constraint`. Prefer the lax entry
    # point and fall back to the experimental one only on older JAX releases
    # that predate it. The local import also avoids relying on the
    # `jax.experimental.pjit` submodule having been imported elsewhere.
    constrain = getattr(jax.lax, "with_sharding_constraint", None)
    if constrain is None:
        from jax.experimental import pjit as _pjit

        constrain = _pjit.with_sharding_constraint
    return constrain(x, axis_resources)