in training/utils/checkpoint_utils.py [0:0]
def get_state_dict(checkpoint, ckpt_state_dict_keys):
    """Walk into a loaded checkpoint and return the object at a key path.

    Args:
        checkpoint: The loaded checkpoint object. If it is a TorchScript
            ``RecursiveScriptModule``, its ``state_dict()`` is returned
            directly and ``ckpt_state_dict_keys`` is ignored. Otherwise it is
            indexed successively with each entry of ``ckpt_state_dict_keys``.
        ckpt_state_dict_keys: Sliceable sequence of keys (for mapping levels)
            or integer indices (for sequence levels) describing the path to
            the desired object. An empty sequence returns ``checkpoint``
            unchanged.

    Returns:
        The object found at the end of the key path.

    Raises:
        KeyError: If a key is missing from a mapping level, or an index is
            out of range for a sequence level, along the path.
    """
    if isinstance(checkpoint, RecursiveScriptModule):
        # This is a torchscript JIT model
        return checkpoint.state_dict()
    pre_train_dict = checkpoint
    for i, key in enumerate(ckpt_state_dict_keys):
        missing = False
        available = None
        if isinstance(pre_train_dict, Mapping):
            missing = key not in pre_train_dict
            available = pre_train_dict.keys()
        elif isinstance(pre_train_dict, Sequence):
            size = len(pre_train_dict)
            # Accept the same index range Python indexing accepts
            # (including negative indices).
            missing = not -size <= key < size
            available = list(range(size))
        if missing:
            # Render the already-traversed prefix, e.g. ["model"]["encoder"];
            # empty when the very first key is the missing one.
            key_str = "".join(f'["{k}"]' for k in ckpt_state_dict_keys[:i])
            raise KeyError(
                f"'{key}' not found in checkpoint{key_str} "
                f"with keys: {available}"
            )
        pre_train_dict = pre_train_dict[key]
    return pre_train_dict