in fairscale/experimental/nn/offload.py [0:0]
def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]:
    """Partition ``modules`` into contiguous shards of similar parameter count.

    The number of parameters of each exposed layer is used as a proxy for its
    memory footprint. Layers are assigned greedily and in order: a layer opens
    the next shard when the current shard is non-empty, adding the layer would
    push it past the per-shard target, and there is still a shard left to
    move to. Because of this, shards toward the end may be left empty.

    As a side effect, the ``data`` of every parameter of every layer is
    replaced with the result of ``pin_memory()``.

    Args:
        modules: The layers to distribute, in order.
        number_splits: Requested number of shards; capped at ``len(modules)``.

    Returns:
        One list of layers per shard, or an empty list when ``modules``
        contains no layers.

    Raises:
        ValueError: If ``number_splits`` is smaller than 1.
    """
    if number_splits < 1:
        raise ValueError(f"number_splits must be at least 1, got {number_splits}")

    # Never create more shards than there are layers to put in them.
    number_splits = min(len(modules), number_splits)
    if number_splits == 0:
        # Nothing to split; this also avoids dividing by zero below.
        return []

    # Count the number of parameters per exposed layer, use that as a proxy for memory footprint.
    # Each layer is counted once here so the loop below can keep running totals
    # instead of re-walking every parameter of the current shard per layer.
    params_per_module = [sum(p.numel() for p in m.parameters()) for m in modules]
    total_number_params = sum(params_per_module)
    number_parameters_per_shard = total_number_params // number_splits

    logging.info(
        f"This model has {total_number_params/1e6:.2f}M parameters, aiming for {number_parameters_per_shard/1e6:.2f}M parameters per shard"
    )

    splits: List[List[nn.Module]] = [[] for _ in range(number_splits)]
    shard_params = [0] * number_splits  # running parameter count of each shard
    current_shard = 0

    for m, module_params in zip(modules, params_per_module):
        for p in m.parameters():
            p.data = p.data.pin_memory()

        # This shard is big enough, point to the next one
        if (
            shard_params[current_shard] > 0
            and shard_params[current_shard] + module_params > number_parameters_per_shard
            and current_shard < number_splits - 1
        ):
            current_shard += 1

        splits[current_shard].append(m)
        shard_params[current_shard] += module_params

    for i, current_shard_params in enumerate(shard_params):
        logging.info(f"Shard {i} holds {current_shard_params/1e6:.2f}M parameters")

    return splits