chatlearn/models/megatron/memory_manager/trainer_v1v2.py [82:140]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def offload_weights(self):
        """
        offload weights
        """
        if self._weights_offloaded:
            log_rank_0('offload_weights called while weights are already offloaded; ignoring.')
            return

        optimizer = self._optimizer

        if self._use_distributed_optimizer:
            optimizer.shard_float16_groups.clear()
            optimizer.shard_fp32_groups.clear()

        if self._group_flat_weights is None:
            if self._use_distributed_optimizer:
                self._group_flat_weights = self._flat_param_groups(
                    [
                        optimizer.model_float16_groups,
                        optimizer.model_fp32_groups,
                    ],
                )
            else:
                self._group_flat_weights = self._flat_param_groups(
                    [
                        optimizer.float16_groups,
                        optimizer.fp32_from_fp32_groups,
                    ],
                )

        for flat_weights in self._group_flat_weights:
            flat_weights.copy_to_primary_store()

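        # Release the stored gradient accumulators (and the hooks registered on
        # them); onload_weights re-registers them against the restored weights.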
        self._model.grad_accs.clear()

        self._weights_offloaded = True

    def onload_weights(self):
        """
        onload weights
        """
        if not self._weights_offloaded:
            log_rank_0('onload_weights called while weights are already onloaded; ignoring.')
            return

        optimizer = self._optimizer

        for flat_weights in self._group_flat_weights:
            flat_weights.copy_to_gpu_buffer()

        model = self._model
        # Re-register grad acc hooks, see Megatron DistributedDataParallel#__init__.
        model.grad_accs = []
        for param in model.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
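
Both offload_weights implementations rely on the elements of self._group_flat_weights, which expose copy_to_primary_store() and copy_to_gpu_buffer(). The chatlearn helper that provides these methods is not shown in the excerpts; the sketch below is a minimal, hypothetical stand-in that illustrates one way such a flat-weights holder could behave, assuming the primary store is pinned host memory and that the GPU copy is released by dropping each parameter's storage. The class name FlatWeightsSketch and its internals are illustrative, not chatlearn's API.

import torch


class FlatWeightsSketch:
    """Hypothetical flat-weights holder; illustrative only, not chatlearn's implementation."""

    def __init__(self, params):
        self._params = list(params)
        self._shapes = [p.shape for p in self._params]
        numel = sum(p.numel() for p in self._params)
        dtype = self._params[0].dtype
        # Pinned host tensor acting as the "primary store".
        self.primary_store = torch.empty(numel, dtype=dtype, pin_memory=True)
        self.gpu_buffer = None

    def copy_to_primary_store(self):
        # Stash the current GPU weights in pinned host memory, then drop the
        # GPU storage so the memory can be reused while weights are offloaded.
        offset = 0
        for p in self._params:
            n = p.numel()
            self.primary_store[offset:offset + n].copy_(p.data.reshape(-1))
            p.data = torch.empty(0, dtype=p.dtype, device=p.device)
            offset += n
        self.gpu_buffer = None

    def copy_to_gpu_buffer(self):
        # Reallocate one flat GPU buffer, copy the weights back, and re-point
        # each parameter at its slice of that buffer.
        self.gpu_buffer = self.primary_store.to('cuda', non_blocking=True)
        offset = 0
        for p, shape in zip(self._params, self._shapes):
            n = shape.numel()
            p.data = self.gpu_buffer[offset:offset + n].view(shape)
            offset += n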



chatlearn/models/megatron/memory_manager/trainer_v4.py [87:145]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def offload_weights(self):
        """
        offload weights
        """
        if self._weights_offloaded:
            log_rank_0('offload_weights called while weights are already offloaded; ignoring.')
            return

        optimizer = self._optimizer

        if self._use_distributed_optimizer:
            optimizer.shard_float16_groups.clear()
            optimizer.shard_fp32_groups.clear()

        if self._group_flat_weights is None:
            if self._use_distributed_optimizer:
                self._group_flat_weights = self._flat_param_groups(
                    [
                        optimizer.model_float16_groups,
                        optimizer.model_fp32_groups,
                    ],
                )
            else:
                self._group_flat_weights = self._flat_param_groups(
                    [
                        optimizer.float16_groups,
                        optimizer.fp32_from_fp32_groups,
                    ],
                )

        for flat_weights in self._group_flat_weights:
            flat_weights.copy_to_primary_store()

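        # Release the stored gradient accumulators (and the hooks registered on
        # them); onload_weights re-registers them against the restored weights.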
        self._model.grad_accs.clear()

        self._weights_offloaded = True

    def onload_weights(self):
        """
        onload weights
        """
        if not self._weights_offloaded:
            log_rank_0('onload_weights called while weights are already onloaded; ignoring.')
            return

        optimizer = self._optimizer

        for flat_weights in self._group_flat_weights:
            flat_weights.copy_to_gpu_buffer()

        model = self._model
        # Re-register grad acc hooks, see Megatron DistributedDataParallel#__init__.
        model.grad_accs = []
        for param in model.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
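
Both excerpts stop at the line that retrieves the AccumulateGrad node; the rest of the loop lies outside the quoted ranges. For context, the Megatron DistributedDataParallel pattern that the in-code comment points to typically continues by registering a hook on that node and keeping a reference to it so the hook stays alive. The sketch below illustrates that pattern only; make_param_hook is a hypothetical callback factory, and this is not the quoted files' code.

def reregister_grad_hooks(model, make_param_hook):
    """Sketch of the grad-accumulator hook re-registration pattern.

    `make_param_hook` is assumed to return a callback suitable for an autograd
    node's register_hook (e.g. one that accumulates the incoming grad into a
    main-grad buffer).
    """
    model.grad_accs = []
    for param in model.module.parameters():
        if param.requires_grad:
            # Expanding yields a non-leaf view whose grad_fn leads to the
            # parameter's AccumulateGrad node.
            param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]
            # Install the hook and keep the accumulator object alive; if the
            # reference is dropped, the hook can be garbage-collected with it.
            grad_acc.register_hook(make_param_hook(param))
            model.grad_accs.append(grad_acc)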




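In both files these methods are the release/resume halves of one cycle: a training replica offloads its weights before the GPU is handed to inference, then onloads them before the next update. The usage sketch below is hypothetical; `memory_manager` stands for an instance of one of the trainer memory managers above, and `run_inference` / `train_step` stand in for the surrounding training loop.

def alternate_inference_and_training(memory_manager, run_inference, train_step, steps):
    # Hypothetical driver showing the intended call order.
    for _ in range(steps):
        # Free the trainer's GPU weight buffers before inference takes over.
        memory_manager.offload_weights()
        rollouts = run_inference()

        # Bring the weights back (and re-register grad hooks) before training.
        memory_manager.onload_weights()
        train_step(rollouts)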