def onload_weights()

in chatlearn/models/megatron/memory_manager/trainer_v3.py [0:0]


    def onload_weights(self):
        """
        onload weights
        """
        if not self._weights_offloaded:
            log_rank_0('onload_weights called while weights are already onloaded. Ignoring.')
            return

        optimizer = self._optimizer

        # Onload param_data of buffers
        for flat_weights in self._group_flat_weights:
            flat_weights.copy_to_gpu_buffer()

        if self._use_distributed_optimizer:
            # Reconstruct references from buckets
            for buffer in self._buffers:
                assert buffer.param_data is not None
                for bucket_id, bucket in enumerate(buffer.buckets):
                    (start_index, end_index) = buffer.bucket_indices[bucket_id]
                    bucket.param_data = None
                    if buffer.param_data is not None:
                        bucket.param_data = buffer._get(
                            torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM
                        )

            # Reconstruct references from params (a sketch of this view-rebuilding idea follows the listing)
            for param, buffer in self._model.param_to_buffer.items():
                data_start_index, _, bucket_id = buffer.param_index_map[param]
                if buffer.param_data is not None:
                    param.data = buffer._get(param._saved_shape, data_start_index, buffer_type=BufferType.PARAM)

        model = self._model
        # Re-register grad acc hooks, see Megatron DistributedDataParallel#__init__
        # (a standalone sketch of this autograd trick follows the listing).
        model.grad_accs = []
        for param in model.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
                grad_acc.register_hook(model._make_param_hook(param, model.param_to_buffer))
                model.grad_accs.append(grad_acc)

        if not self._use_distributed_optimizer:
            self._weights_offloaded = False
            return

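        # Rebuild the optimizer's data-parallel views over the param buffers (used for the param all-gather).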
        optimizer.pbuf_view_items = optimizer._get_model_param_buffer_dp_views()

        shard_float16_groups = optimizer.shard_float16_groups
        shard_fp32_groups = optimizer.shard_fp32_groups
        param_gbuf_map = optimizer.model_param_gbuf_map
        opt_group_ranges = optimizer.opt_group_ranges
        model_gbuf_ranges = optimizer.gbuf_ranges

        # Rebuild shard_float16_groups and shard_fp32_groups,
        # see Megatron DistributedOptimizer#build_model_and_main_param_groups.
        for group_range in opt_group_ranges:
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)

            for model_param in group_range["params"]:
                assert model_param.requires_grad
                gbuf_index, dtype, bucket_index = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[gbuf_index][dtype][bucket_index]
                param_range = gbuf_range["param_map"][model_param]["param"]

                # fp16, bf16 params.
                if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
                    shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end]
                    tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared

                    shard_float16_params_this_group.append(shard_model_param)

                # fp32 params.
                elif model_param.type() == 'torch.cuda.FloatTensor':
                    shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
                    shard_fp32_params_this_group.append(shard_model_param)
                    tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                else:
                    raise TypeError(
                        'Wrapped parameters must be one of '
                        'torch.cuda.FloatTensor, '
                        'torch.cuda.HalfTensor, or '
                        'torch.cuda.BFloat16Tensor. '
                        'Received {}'.format(model_param.type())
                    )

        self._weights_offloaded = False
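
The two "Reconstruct references" loops above restore buckets and parameters as views into a single flat param_data buffer, so copying that buffer back to the GPU restores every tensor at once. Below is a minimal, self-contained sketch of that view-rebuilding idea; the names flat_buffer, saved_shape, and start_index are illustrative stand-ins for buffer.param_data, param._saved_shape, and the start index from param_index_map, not ChatLearn API.

    import torch

    flat_buffer = torch.arange(12, dtype=torch.float32)  # stands in for buffer.param_data
    saved_shape, start_index = torch.Size([3, 4]), 0      # stand in for param._saved_shape / data_start_index

    # Roughly what buffer._get(shape, start, ...) does: slice the flat storage and reshape it.
    param_view = flat_buffer[start_index : start_index + saved_shape.numel()].view(saved_shape)

    param_view.zero_()                                         # writing through the view...
    assert torch.all(flat_buffer[: saved_shape.numel()] == 0)  # ...mutates the shared flat storage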
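
The hook re-registration relies on a standard PyTorch autograd trick rather than anything ChatLearn-specific: expanding a leaf parameter produces a non-leaf view whose grad_fn leads back to the parameter's AccumulateGrad node, and hooks registered on that node fire whenever a gradient for the parameter is computed. A standalone sketch of the trick (the printed message and variable names are illustrative only):

    import torch

    param = torch.nn.Parameter(torch.randn(4))
    param_tmp = param.expand_as(param)                 # non-leaf view, so it carries a grad_fn
    grad_acc = param_tmp.grad_fn.next_functions[0][0]  # the parameter's AccumulateGrad node

    # The hook fires whenever a gradient for `param` is accumulated during backward().
    grad_acc.register_hook(lambda *unused: print("gradient accumulated for param"))

    (param * 2.0).sum().backward()                     # triggers the hook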