def lazy_init_with_master_weights()

in apex/apex/amp/_process_optimizer.py
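
Lazily builds the master-weight bookkeeping for an amp-patched optimizer: every
torch.cuda.HalfTensor parameter in the optimizer's param_groups is replaced by an
FP32 master copy (recorded in stash.fp32_from_fp16_groups), FP32 parameters are
kept in place, and any preexisting per-param optimizer state is re-keyed to the
new master params and recast via a state_dict()/load_state_dict() round trip
(load_state_dict casts per-param state tensors to match the params they belong to).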


def lazy_init_with_master_weights(self):
        stash = self._amp_stash
        stash.fp16_groups = []
        stash.fp32_from_fp16_groups = []
        stash.fp32_from_fp32_groups = []
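        # Split each param group into fp16 and fp32 params, replacing every fp16
        # param in param_group['params'] with a newly created fp32 master copy.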
        for i, param_group in enumerate(self.param_groups):
            # maybe_print("FP16_Optimizer processing param group {}:".format(i))
            fp16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_fp16_params_this_group = []
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:
                    if param.type() == 'torch.cuda.HalfTensor':
                        # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                        #             .format(param.size()))
                        fp16_params_this_group.append(param)
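                        # Create an FP32 master copy; the optimizer will step on
                        # this copy rather than the original fp16 param.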
                        master_param = param.detach().clone().float()
                        master_param.requires_grad = True
                        param_group['params'][i] = master_param
                        fp32_from_fp16_params_this_group.append(master_param)
                        # Reset existing state dict key to the new master param.
                        # We still need to recast per-param state tensors, if any, to FP32.
                        if param in self.state:
                            self.state[master_param] = self.state.pop(param)
                    elif param.type() == 'torch.cuda.FloatTensor':
                        # maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                        #             .format(param.size()))
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param
                    else:
                        raise TypeError("Optimizer's parameters must be either "
                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                        "Received {}".format(param.type()))

            stash.fp16_groups.append(fp16_params_this_group)
            stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            stash.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Flattened, cross-group views of the three parameter sets.
        stash.all_fp16_params = []
        for group in stash.fp16_groups:
            stash.all_fp16_params += group

        stash.all_fp32_from_fp16_params = []
        for group in stash.fp32_from_fp16_groups:
            stash.all_fp32_from_fp16_params += group

        stash.all_fp32_from_fp32_params = []
        for group in stash.fp32_from_fp32_groups:
            stash.all_fp32_from_fp32_params += group

        # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
        stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]

        # Both the new FP32 master copies and the existing FP32 params start out
        # with their gradients cleared.
        for param in stash.all_fp32_from_fp16_params:
            param.grad = None

        for param in stash.all_fp32_from_fp32_params:
            param.grad = None

        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
        self.load_state_dict(self.state_dict())
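
The groups built here back amp's master-weight path (opt_level "O2", which trains
an FP16 model against FP32 master copies of its weights). Below is a minimal usage
sketch under which this lazy initialization gets exercised; the loader yielding
CUDA tensors is a hypothetical stand-in, and a CUDA device plus an apex install
are assumed:

import torch
import torch.nn.functional as F
from apex import amp

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# O2 casts the model to FP16 and patches the optimizer so that the master-weight
# setup above can run lazily once training starts.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

for data, target in loader:  # hypothetical DataLoader yielding CUDA tensors
    optimizer.zero_grad()
    loss = F.mse_loss(model(data), target)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

After the lazy init runs, optimizer.param_groups hold the FP32 master copies in
place of the original FP16 params, while those FP16 params remain tracked in
stash.fp16_groups.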