in torchrec/distributed/planner/proposers.py [0:0]
def load(self, search_space: List[ShardingOption]) -> None:
    """Build one uniform proposal per sharding type that covers every fqn.

    Groups ``search_space`` by ``sharding_type`` and then by ``fqn``. A
    sharding type produces a proposal only if it has at least one option for
    every fqn that appears anywhere in ``search_space``. That proposal holds,
    for each fqn, the option with the smallest
    ``_sharding_option_score(option, self._use_depth)``. Ties keep the option
    that appeared earliest in ``search_space``.

    Proposals are appended to ``self._grouped_sharding_options`` in order of
    each sharding type's first appearance. Within a proposal, options follow
    the order in which each fqn was first seen for that type.

    NOTE(review): existing entries in ``self._grouped_sharding_options`` are
    not cleared here, so calling ``load`` repeatedly accumulates proposals.
    Confirm that callers reset it, or that this is intended.

    Args:
        search_space: candidate sharding options to group and rank.
    """
    all_fqns = set()
    sharding_options_by_type_and_fqn: Dict[
        str, Dict[str, List[ShardingOption]]
    ] = {}
    for sharding_option in search_space:
        all_fqns.add(sharding_option.fqn)
        sharding_options_by_type_and_fqn.setdefault(
            sharding_option.sharding_type, {}
        ).setdefault(sharding_option.fqn, []).append(sharding_option)

    for sharding_options_by_fqn in sharding_options_by_type_and_fqn.values():
        # Only a sharding type that has options for every fqn can form a
        # uniform plan; skip the others before scoring anything.
        if sharding_options_by_fqn.keys() != all_fqns:
            continue
        # min() returns the first minimal element, which is exactly what a
        # stable ascending sort followed by [0] would pick, without sorting.
        self._grouped_sharding_options.append(
            [
                min(
                    sharding_options,
                    key=lambda x: _sharding_option_score(x, self._use_depth),
                )
                for sharding_options in sharding_options_by_fqn.values()
            ]
        )