def _partition_by_host()

in torchrec/distributed/planner/partitioners.py [0:0]


    def _partition_by_host(self, sharding_options: List[ShardingOption]) -> None:
        """Place ``sharding_options`` on hosts first, then on devices within each host.

        Host level: every host is treated as one partition made of a contiguous
        slice of ``local_world_size`` devices from ``self._topology.devices``.
        A host's memory capacity is the minimum ``storage`` across its devices,
        and its starting perf is the maximum ``perf`` across its devices.
        ``greedy_partition`` then assigns each sharding option to a host. Only
        the option's first shard (index 0) is used for this.

        Device level: within each host, the options whose base partition type
        is ``PartitionByType.UNIFORM`` go to ``self._uniform_device_level_partition``.
        Those whose base type is ``PartitionByType.DEVICE`` go to
        ``self._greedy_device_level_partition``. Both fill a shared per-device
        ``partitions`` list, which is finally handed to ``self._update_shards``.

        NOTE(review): ``num_hosts`` uses floor division. If ``world_size`` is not a
        multiple of ``local_world_size``, trailing devices belong to no host and
        get nothing from this method. Presumably this is validated elsewhere;
        confirm.

        Args:
            sharding_options: the sharding options to place. They are passed,
                together with the computed partitions, to ``self._update_shards``.
        """
        # pyre-ignore [16]: `GreedyPerfPartitioner` has no attribute `_topology`.
        num_hosts: int = self._topology.world_size // self._topology.local_world_size
        local_world_size: int = self._topology.local_world_size

        # Only the first shard of each sharding option takes part in host-level
        # placement; the rest can be inferred from it.
        shard_idxes: List[Tuple[int, int]] = [
            (option_idx, 0) for option_idx in range(len(sharding_options))
        ]

        host_level_devices: Dict[int, List[DeviceHardware]] = {}
        mem_cap: List[Storage] = []
        partition_sums: List[float] = []
        for host_idx in range(num_hosts):
            start = host_idx * local_world_size
            end = start + local_world_size
            devices_in_host = self._topology.devices[start:end]
            host_level_devices[host_idx] = devices_in_host

            # mem_cap of a host is the min of the storage of all devices on that host.
            mem_cap.append(min(device.storage for device in devices_in_host))
            # perf of a host is the max across all of its devices. Typically this
            # should be zero at entry point.
            partition_sums.append(
                max(float(device.perf) for device in devices_in_host)
            )

        host_level_partitions: List[List[Tuple[int, int]]] = greedy_partition(
            num_partitions=num_hosts,
            sharding_options=sharding_options,
            shard_idxes=shard_idxes,
            partition_sums=partition_sums,
            mem_cap=mem_cap,
        )
        # One (initially empty) partition per device in the whole topology.
        partitions: List[List[Tuple[int, int]]] = [[] for _ in self._topology.devices]

        for host_idx, host_partition in enumerate(host_level_partitions):
            # Classify this host's options in a single pass so that
            # `_base_partition_by` is evaluated once per option. Relative order
            # within each list matches the order in `host_partition`.
            # NOTE(review): options whose base type is neither UNIFORM nor DEVICE
            # land in neither list and are not placed here (same as before).
            uniform_option_idxes: List[int] = []
            device_option_idxes: List[int] = []
            for option_idx, _ in host_partition:
                partition_by = _base_partition_by(
                    sharding_options[option_idx].sharding_type
                )
                if partition_by == PartitionByType.UNIFORM.value:
                    uniform_option_idxes.append(option_idx)
                elif partition_by == PartitionByType.DEVICE.value:
                    device_option_idxes.append(option_idx)

            self._uniform_device_level_partition(
                partitions=partitions,
                sharding_options=sharding_options,
                option_idxes=uniform_option_idxes,
                host_level_devices=host_level_devices[host_idx],
                host_idx=host_idx,
            )

            self._greedy_device_level_partition(
                partitions=partitions,
                sharding_options=sharding_options,
                option_idxes=device_option_idxes,
                host_level_devices=host_level_devices[host_idx],
                host_idx=host_idx,
            )

        self._update_shards(partitions, sharding_options)