in crypten/nn/module.py [0:0]
def _load_from_state_dict_crypten(self, state_dict, prefix, strict):
    """
    Copies parameters from `state_dict` into this module. This is called on
    every submodule in the `load_state_dict` function.

    Only the entries yielded by `self.named_parameters()` are considered.
    NOTE(review): the intent is to handle this module but not its children;
    whether `named_parameters()` recurses by default is not visible here --
    confirm against its definition.

    Args:
        state_dict: mapping from fully-qualified parameter names to the
            values that should be copied into the parameters.
        prefix: string prepended to each local parameter name to form the
            key that is looked up in `state_dict`.
        strict: if true, every local parameter must have an entry in
            `state_dict`; otherwise parameters without an entry are left
            untouched.

    Raises:
        ValueError: if `strict` is set and a key is missing from
            `state_dict`, or if the size of an entry differs from the size
            of the parameter it would be copied into.
        AssertionError: if an encrypted value is loaded into an unencrypted
            module, or an unencrypted value into an encrypted module.
    """
    # collect the (non-None) parameters this call is responsible for
    local_state = {
        name: param
        for name, param in self.named_parameters()
        if param is not None
    }

    # in strict mode, check for missing keys before anything is copied, so a
    # missing key never leaves the module partially loaded:
    if strict:
        for name in local_state:
            key = prefix + name
            if key not in state_dict:
                raise ValueError("Key {} not found in state dict.".format(key))

    # loop over parameters in module:
    for name, param in local_state.items():
        key = prefix + name
        if key not in state_dict:
            # only reachable when strict is false: keep the current value
            continue
        input_param = state_dict[key]

        # size in state_dict should match size of parameters:
        if input_param.size() != param.size():
            raise ValueError(
                "Size mismatch for {}: copying a param with "
                "shape {} from checkpoint, the shape in "
                "current model is {}.".format(
                    key, input_param.size(), param.size()
                )
            )

        # cannot copy encrypted tensors into unencrypted models and vice
        # versa; raised explicitly (rather than via `assert`) so the check
        # is not stripped when Python runs with -O:
        param_encrypted = isinstance(input_param, crypten.CrypTensor)
        if param_encrypted and not self.encrypted:
            raise AssertionError(
                "cannot copy encrypted parameters into unencrypted model"
            )
        if not param_encrypted and self.encrypted:
            raise AssertionError(
                "cannot copy unencrypted parameters into encrypted model"
            )

        # copy parameters from state_dict without recording gradients:
        with crypten.no_grad(), torch.no_grad():
            param.copy_(input_param)