in torchrec/distributed/planner/stats.py [0:0]
def _get_sharding_type_abbr(sharding_type: str) -> str:
    """Map a sharding-type string to the short label used in planner stats.

    The input is compared (with ``==``) against the ``.value`` of each
    supported ``ShardingType`` member, in the order listed below.

    Args:
        sharding_type: the sharding type, given as a ``ShardingType`` value.

    Returns:
        One of ``"DP"``, ``"TW"``, ``"CW"``, ``"RW"``, ``"TWRW"`` or ``"TWCW"``.

    Raises:
        ValueError: if ``sharding_type`` matches none of the supported values.
    """
    # Ordered (member, abbreviation) pairs; compared via ``==`` on ``.value``
    # so matching behaves exactly like a sequential if/elif chain.
    abbreviations = (
        (ShardingType.DATA_PARALLEL, "DP"),
        (ShardingType.TABLE_WISE, "TW"),
        (ShardingType.COLUMN_WISE, "CW"),
        (ShardingType.ROW_WISE, "RW"),
        (ShardingType.TABLE_ROW_WISE, "TWRW"),
        (ShardingType.TABLE_COLUMN_WISE, "TWCW"),
    )
    for member, abbr in abbreviations:
        if sharding_type == member.value:
            return abbr
    raise ValueError(
        f"Unrecognized or unsupported sharding type provided: {sharding_type}"
    )