in src/bert_train.py [0:0]
def try_load_statedict_from_checkpoint(self):
    """Attempt to load a model state dict from ``self.checkpoint_dir``.

    Returns:
        The ``model_state_dict`` entry of the first ``*.pt`` checkpoint file
        (in sorted filename order) found in ``self.checkpoint_dir``, or
        ``None`` when no directory is configured or no checkpoint exists.
    """
    loaded_weights = None
    if self.checkpoint_dir is not None:
        # Sort so the selected file is deterministic: glob.glob returns
        # entries in arbitrary, platform-dependent order.
        model_files = sorted(glob.glob("{}/*.pt".format(self.checkpoint_dir)))
        if model_files:
            model_file = model_files[0]
            self._logger.info(
                "Loading checkpoint {} , found {} checkpoint files".format(model_file, len(model_files)))
            # map_location="cpu" lets a GPU-saved checkpoint load on a
            # CPU-only host; callers place the weights on their device.
            checkpoint = torch.load(model_file, map_location="cpu")
            loaded_weights = checkpoint['model_state_dict']
    return loaded_weights