in training/flax/distil_whisper/partitioner.py [0:0]
def with_sharding_constraint(x, axis_resources):
    """Apply a pjit sharding constraint to ``x`` only where it can take effect.

    Thin wrapper around ``jax.experimental.pjit.with_sharding_constraint``
    that degrades to a no-op when running on CPU or outside pjit.

    Args:
      x: The value to constrain; passed through unchanged in the no-op case.
      axis_resources: Sharding specification forwarded verbatim to
        ``jax.experimental.pjit.with_sharding_constraint``.

    Returns:
      The constrained value, or ``x`` itself when no constraint is applied.
    """
    # Only the first device's platform is inspected.
    first_platform = jax.devices()[0].platform
    # ``global_mesh_defined`` is consulted only when not on CPU, matching the
    # short-circuit order of the original check.
    if first_platform != "cpu" and global_mesh_defined():
        return jax.experimental.pjit.with_sharding_constraint(x, axis_resources)
    # On CPU, or no global mesh defined (i.e. outside pjit): leave x as-is.
    return x