modules/SwissArmyTransformer/sat/training/model_io.py
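# Imports used by this function (assumed to sit at the top of model_io.py,
# following the sat package layout).
import os
import random

import numpy as np
import torch

from sat import mpu
from sat.helpers import print_all, print_rank0
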
def load_checkpoint(model, args, load_path=None, prefix='', specific_iteration=None):
"""Load a model checkpoint."""
if load_path is None:
load_path = args.load
    # If args lacks training attributes (model-only mode), default to inference.
if not hasattr(args, 'mode'):
from copy import deepcopy
args = deepcopy(args)
args.mode = 'inference'
# iteration, release, success = get_checkpoint_iteration(load_path)
# if specific_iteration is not None:
# assert type(specific_iteration) == int and specific_iteration > 0
# print_rank0('Overriding checkpoint iteration to {}'.format(specific_iteration))
# iteration = specific_iteration
# if not success:
# return 0
# checkpoint_name = get_checkpoint_name(load_path, iteration, release)
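    # Iteration bookkeeping (above) is disabled: load this rank's model-parallel
    # shard directly from `load_path`, using the DeepSpeed-style filename.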
checkpoint_name = os.path.join(load_path, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank()))
iteration = 0
if mpu.get_data_parallel_rank() == 0:
print_all('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
    # load the state dict onto the CPU
sd = torch.load(checkpoint_name, map_location='cpu')
    # if `prefix` is given, load only that sub-module's keys from the checkpoint, e.g. the encoder
new_sd = {'module':{}}
for k in sd:
if k != 'module':
new_sd[k] = sd[k]
for k in sd['module']:
if k.startswith(prefix):
new_sd['module'][k[len(prefix):]] = sd['module'][k]
sd = new_sd
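    # e.g. with prefix='encoder.', 'encoder.layers.0.weight' is loaded as
    # 'layers.0.weight', and keys outside the prefix are dropped.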
if hasattr(model, 'module'):
module = model.module
else: # inference without deepspeed
module = model
    # only load the module's weights; the other entries are hyperparameters kept for the record.
missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
if len(unexpected_keys) > 0:
        print_rank0(
            f'Will continue, but found unexpected keys! Check whether you are loading the correct checkpoint: {unexpected_keys}.')
if len(missing_keys) > 0:
if args.mode == 'inference':
            if getattr(args, 'force_inference', False):
print_rank0(f'Warning: Missing keys for inference: {missing_keys}.')
else:
                raise ValueError(f'Missing keys for inference: {missing_keys}.\nIf you still want to run inference anyway, pass --force_inference to args.')
else: # new params
if not args.force_train:
                assert all('mixins' in name or 'cross_attention' in name for name in missing_keys), missing_keys
assert args.mode == 'finetune'
# list all mixin names
mixin_names = []
for key_name in missing_keys:
                if 'mixins' not in key_name:
continue
parts = key_name.split('.')
mixin_name = parts[parts.index('mixins')+1]
if mixin_name not in mixin_names:
mixin_names.append(mixin_name)
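            # reinit freshly initializes only these newly added mixins; weights
            # already loaded from the checkpoint are left untouched.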
module.reinit(mixin_names) # initialize mixins
    # No longer needed: the optimizer is now created after loading.
# if args.mode != 'inference' and args.deepspeed and args.fp16:
# model.optimizer.refresh_fp32_params() # restore fp32 weights from module
# Iterations.
if args.mode == 'finetune':
iteration = 0
elif args.mode == 'pretrain' and not args.no_load_rng: # rng states.
try:
random.setstate(sd['random_rng_state'])
np.random.set_state(sd['np_rng_state'])
torch.set_rng_state(sd['torch_rng_state'])
torch.cuda.set_rng_state(sd['cuda_rng_state'])
mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
except KeyError:
            print_rank0('Unable to load the random state from checkpoint {}, exiting. '
                        'Specify --no-load-rng or --finetune to skip loading '
                        'the random state.'.format(checkpoint_name))
exit()
elif args.mode == 'inference':
module.eval()
if mpu.get_data_parallel_rank() == 0:
print_all('> successfully loaded {}'.format(checkpoint_name))
del sd
return iteration
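
# Example usage (a minimal sketch: `MyModel`, the path, and the argument
# namespace are hypothetical; `args` is assumed to come from sat's argument
# parser with mode='inference'):
#
#     model = MyModel(args)
#     iteration = load_checkpoint(model, args, load_path='/path/to/checkpoint',
#                                 prefix='encoder.')
#     # `iteration` is the training iteration to resume from (always 0 in this
#     # version, since iteration bookkeeping is commented out).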