in torchrec/distributed/planner/partitioners.py [0:0]
def _partition_by_host(self, sharding_options: List[ShardingOption]) -> None:
    """Place sharding options on hosts first, then on devices within each host.

    Host-level placement uses ``greedy_partition`` over one entry per sharding
    option, namely its first shard (shard index 0). Each host is summarized by
    the minimum storage across its devices (``mem_cap``) and the maximum perf
    across its devices (``partition_sums``).

    Within each host, the options assigned to it are split by the result of
    ``_base_partition_by`` on their sharding type:
    ``PartitionByType.UNIFORM`` options go to
    ``_uniform_device_level_partition`` and ``PartitionByType.DEVICE`` options
    go to ``_greedy_device_level_partition``. Options matching neither value
    are not passed to either. The per-device result is handed to
    ``self._update_shards``.

    Args:
        sharding_options: options to place. The method returns ``None``, so
            the outcome is presumably written back onto these options by
            ``_update_shards`` — confirm against that helper.
    """
    # pyre-ignore [16]: `GreedyPerfPartitioner` has no attribute `_topology`.
    topology = self._topology
    local_world_size: int = topology.local_world_size
    # NOTE(review): floor division means trailing devices are never assigned a
    # host if world_size is not a multiple of local_world_size — confirm intended.
    num_hosts: int = topology.world_size // local_world_size

    # Only take the first shard from each sharding option. We can infer the rest.
    shard_idxes: List[Tuple[int, int]] = [
        (option_idx, 0) for option_idx, _ in enumerate(sharding_options)
    ]

    host_level_devices: Dict[int, List[DeviceHardware]] = {}
    mem_cap: List[Storage] = []
    partition_sums: List[float] = []
    for host_idx in range(num_hosts):
        # Hosts own contiguous, equally sized slices of the device list.
        start = host_idx * local_world_size
        devices_in_host = topology.devices[start : start + local_world_size]
        host_level_devices[host_idx] = devices_in_host
        # mem_cap of a host is the min of the storage of all devices on that host.
        mem_cap.append(min(device.storage for device in devices_in_host))
        # perf of a host is the max across all of its devices. Typically this
        # should be zero at entry point.
        partition_sums.append(max(float(device.perf) for device in devices_in_host))

    host_level_partitions: List[List[Tuple[int, int]]] = greedy_partition(
        num_partitions=num_hosts,
        sharding_options=sharding_options,
        shard_idxes=shard_idxes,
        partition_sums=partition_sums,
        mem_cap=mem_cap,
    )

    # One initially empty partition per device in the whole topology; the
    # device-level partitioners below fill these in.
    partitions: List[List[Tuple[int, int]]] = [[] for _ in topology.devices]
    for host_idx, host_partition in enumerate(host_level_partitions):
        # Split this host's options by partitioning strategy in a single pass,
        # preserving their original order within each bucket.
        uniform_option_idxes: List[int] = []
        device_option_idxes: List[int] = []
        for option_idx, _ in host_partition:
            partition_by = _base_partition_by(
                sharding_options[option_idx].sharding_type
            )
            if partition_by == PartitionByType.UNIFORM.value:
                uniform_option_idxes.append(option_idx)
            elif partition_by == PartitionByType.DEVICE.value:
                device_option_idxes.append(option_idx)

        devices_in_host = host_level_devices[host_idx]
        self._uniform_device_level_partition(
            partitions=partitions,
            sharding_options=sharding_options,
            option_idxes=uniform_option_idxes,
            host_level_devices=devices_in_host,
            host_idx=host_idx,
        )
        self._greedy_device_level_partition(
            partitions=partitions,
            sharding_options=sharding_options,
            option_idxes=device_option_idxes,
            host_level_devices=devices_in_host,
            host_idx=host_idx,
        )
    self._update_shards(partitions, sharding_options)