def _load_from_state_dict_crypten()

in crypten/nn/module.py [0:0]


    def _load_from_state_dict_crypten(self, state_dict, prefix, strict):
        """
        Copies parameters and buffers from `state_dict` into only this module
        but not its children. This is called on every submodule in the
        `load_state_dict` function.

        Args:
            state_dict (dict): mapping from full (prefixed) parameter names
                to tensor values.
            prefix (str): string prepended to this module's local parameter
                names to form keys into `state_dict`.
            strict (bool): if True, raise when one of this module's
                parameters is missing from `state_dict`.

        Raises:
            ValueError: on a missing key (strict mode only), a shape
                mismatch, or an encryption mismatch between a `state_dict`
                value and the model's encryption state.
        """

        # get state dict for just the current module (without children)
        # NOTE(review): despite the docstring, only parameters are gathered
        # here -- buffers from named_buffers() are never copied; confirm
        # against callers whether that is intentional.
        local_state = {
            key: val for key, val in self.named_parameters() if val is not None
        }

        # in strict mode, check for missing keys in the state_dict:
        if strict:
            for name in local_state.keys():
                key = prefix + name
                if key not in state_dict:
                    raise ValueError("Key {} not found in state dict.".format(key))

        # loop over parameters / buffers in module:
        for name, param in local_state.items():
            key = prefix + name
            input_param = state_dict[key]

            # size in state_dict should match size of parameters:
            if input_param.size() != param.size():
                # BUG FIX: the implicitly concatenated message fragments were
                # missing separating spaces ("withshape", "incurrent"); an
                # unreachable `continue` after this raise was also removed.
                raise ValueError(
                    "Size mismatch for {}: copying a param with "
                    "shape {} from checkpoint, the shape in "
                    "current model is {}.".format(key, input_param.size(), param.size())
                )

            # cannot copy encrypted tensors into unencrypted models and vice versa.
            # Raise ValueError instead of assert so the check survives
            # `python -O` and matches the error style used above.
            param_encrypted = isinstance(input_param, crypten.CrypTensor)
            if param_encrypted and not self.encrypted:
                raise ValueError("cannot copy encrypted parameters into unencrypted model")
            if not param_encrypted and self.encrypted:
                raise ValueError("cannot copy unencrypted parameters into encrypted model")

            # copy parameters from state_dict without recording autograd history:
            with crypten.no_grad(), torch.no_grad():
                param.copy_(input_param)