def _init_global_momentum_buffers()

in fairscale/experimental/nn/data_parallel/gossip/distributed.py


    def _init_global_momentum_buffers(self, optimizer: torch.optim.Optimizer) -> None:
        """Initializes the slow momentum buffers"""
        self.global_momentum_buffers_initialized = True

        if not self.slowmo:
            return

        total_elements = 0
        params_dtype = None
        for group in optimizer.param_groups:
            for p in group["params"]:
                total_elements += p.numel()

                # Assert that all parameters have the same device and dtype
                if params_dtype is None:
                    params_dtype, params_device = p.dtype, p.device
                # Check that the dtype is fp32, since slow momentum is performed in fp32
                assert p.dtype == params_dtype == torch.float32
                assert p.device == params_device

        self.world_portion_length = (total_elements + self.slowmo_num_shards - 1) // self.slowmo_num_shards
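        # Ceiling division: e.g., with total_elements = 350 and slowmo_num_shards = 4,
        # world_portion_length = (350 + 3) // 4 = 88 (the last shard ends up smaller)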

        if not self.is_current_node_a_slowmo_shard:
            return

        self.portion_start = self.process_rank * self.world_portion_length if self.slowmo_memory_efficient else 0
        self.portion_end = (
            min((self.process_rank + 1) * self.world_portion_length, total_elements)
            if self.slowmo_memory_efficient
            else total_elements
        )
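        # Example (memory-efficient mode, same numbers as above): rank 3 owns the
        # slice [264, 350); without memory efficiency every shard covers [0, 350)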

        self.old_params = torch.empty(self.world_portion_length, dtype=params_dtype, device=params_device).detach()

        # Initialize old_params with a flattened copy of this shard's slice of the parameters
        offset = 0
        for group in optimizer.param_groups:
            for p in group["params"]:
                numel = p.numel()

                if offset + numel > self.portion_start and offset < self.portion_end:

                    # Compute the overlap of this parameter's flattened range
                    # [offset, offset + numel) with this shard's portion
                    # [portion_start, portion_end)
                    overall_start = max(self.portion_start, offset)
                    overall_end = min(self.portion_end, offset + numel)

                    p_start = overall_start - offset
                    p_end = overall_end - offset

                    buffer_start = overall_start - self.portion_start
                    buffer_end = overall_end - self.portion_start
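
                    # Worked example: portion [264, 350) and a parameter of 200
                    # elements at offset 200 overlap on [264, 350), giving
                    # p_start = 64, p_end = 150, buffer_start = 0, buffer_end = 86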

                    # Slice out the overlapping region of p and the matching region of old_params
                    current_p = p.view(-1)[p_start:p_end]
                    current_p_old = self.old_params[buffer_start:buffer_end]

                    current_p_old.copy_(current_p)

                offset += numel

        self.global_momentum_buffer = torch.zeros_like(self.old_params).detach()
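
For context, here is a minimal sketch of how the buffers initialized above are
typically consumed by the slow momentum (outer) step of SlowMo. The function and
the slowmo_momentum/slowmo_lr names are illustrative assumptions for this sketch,
not the library's exact API; the real update logic lives elsewhere in
distributed.py.

import torch

def slow_momentum_step_sketch(
    flat_params: torch.Tensor,      # this shard's slice of the current parameters, flattened
    old_params: torch.Tensor,       # the snapshot taken by _init_global_momentum_buffers
    momentum_buffer: torch.Tensor,  # the zero-initialized global_momentum_buffer
    base_lr: float,                 # learning rate of the inner optimizer
    slowmo_momentum: float = 0.5,   # hypothetical default for this sketch
    slowmo_lr: float = 1.0,         # hypothetical default for this sketch
) -> None:
    # u <- beta * u + (x_old - x_new) / lr: fold the inner progress into the buffer
    momentum_buffer.mul_(slowmo_momentum).add_(old_params - flat_params, alpha=1.0 / base_lr)
    # x_old <- x_old - slowmo_lr * lr * u: take the outer SGD step from the snapshot
    old_params.add_(momentum_buffer, alpha=-slowmo_lr * base_lr)
    # Reset the live parameters to the updated snapshot for the next inner phase
    flat_params.copy_(old_params)

In memory-efficient mode, each rank would run a step like this only on its own
[portion_start, portion_end) slice before the results are scattered back.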