in fairscale/experimental/nn/data_parallel/gossip/distributed.py [0:0]
def _init_global_momentum_buffers(self, optimizer: torch.optim.Optimizer) -> None:
    """Initializes the slow momentum buffers"""
    self.global_momentum_buffers_initialized = True

    if not self.slowmo:
        return

    total_elements = 0
    params_dtype = None
    for group in optimizer.param_groups:
        for p in group["params"]:
            total_elements += p.numel()

            # Assert that all parameters have the same device and dtype
            if params_dtype is None:
                params_dtype, params_device = p.dtype, p.device
            # Check that dtype is fp32 since slow momentum is to be performed in fp32
            assert p.dtype == params_dtype == torch.float32
            assert p.device == params_device
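    # Split the flattened parameter space across slowmo_num_shards shards
    # (ceiling division so that the shards cover every element)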
    self.world_portion_length = (total_elements + self.slowmo_num_shards - 1) // self.slowmo_num_shards

    if not self.is_current_node_a_slowmo_shard:
        return
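    # In memory-efficient mode each shard owns the contiguous slice
    # [portion_start, portion_end) of the flattened parameters; otherwise this
    # node covers the full range [0, total_elements)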
    self.portion_start = self.process_rank * self.world_portion_length if self.slowmo_memory_efficient else 0
    self.portion_end = (
        min((self.process_rank + 1) * self.world_portion_length, total_elements)
        if self.slowmo_memory_efficient
        else total_elements
    )
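    # Flat fp32 buffer that will hold a copy of this node's parameter slice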
    self.old_params = torch.empty(self.world_portion_length, dtype=params_dtype).to(params_device).detach()

    # Initialize old_params with a copy of the current parameter values
    offset = 0
    for group in optimizer.param_groups:
        for p in group["params"]:
            numel = p.numel()

            if offset + numel > self.portion_start and offset < self.portion_end:
                # Compute the overlap of this parameter with this node's shard
                # [portion_start, portion_end) in the flattened parameter space
                overall_start = max(self.portion_start, offset)
                overall_end = min(self.portion_end, offset + numel)
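                # Translate the global overlap window into local indices:
                # p_start/p_end slice the flattened parameter, while
                # buffer_start/buffer_end slice this node's old_params buffer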
                p_start = overall_start - offset
                p_end = overall_end - offset

                buffer_start = overall_start - self.portion_start
                buffer_end = overall_end - self.portion_start

                # Copy the overlapping slice of the flattened parameter into old_params
                current_p = p.view(-1)[p_start:p_end]
                current_p_old = self.old_params[buffer_start:buffer_end]
                current_p_old.copy_(current_p)

            offset += numel
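    # Slow momentum buffer: zero-initialized, with the same shape, dtype and device as old_params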
    self.global_momentum_buffer = torch.zeros_like(self.old_params).detach()
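Below is a minimal, self-contained sketch of the same shard-overlap arithmetic, using toy sizes and hypothetical names that are not part of the fairscale source: two parameters are flattened end to end, one shard's window is chosen, and the overlapping slices are copied into a flat shard buffer.

import torch

# Toy setup (hypothetical sizes): two parameters, flattened one after another
params = [torch.arange(6, dtype=torch.float32), torch.arange(10, dtype=torch.float32)]
total_elements = sum(p.numel() for p in params)                        # 16
num_shards = 2
portion_length = (total_elements + num_shards - 1) // num_shards       # 8
rank = 1                                                               # pretend this node is shard 1
portion_start = rank * portion_length                                  # 8
portion_end = min((rank + 1) * portion_length, total_elements)         # 16

shard_buffer = torch.empty(portion_length, dtype=torch.float32)

offset = 0
for p in params:
    numel = p.numel()
    # Same overlap test and window arithmetic as in _init_global_momentum_buffers
    if offset + numel > portion_start and offset < portion_end:
        overall_start = max(portion_start, offset)
        overall_end = min(portion_end, offset + numel)
        p_start, p_end = overall_start - offset, overall_end - offset
        b_start, b_end = overall_start - portion_start, overall_end - portion_start
        shard_buffer[b_start:b_end].copy_(p.view(-1)[p_start:p_end])
    offset += numel

# Shard 1 holds the last 8 of the 16 flattened elements, i.e. elements 2..9 of the second parameter
print(shard_buffer)  # tensor([2., 3., 4., 5., 6., 7., 8., 9.])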