in src/bert_train.py [0:0]
def create_checkpoint(self, model, checkpoint_dir):
    """Save the model's weights to ``<checkpoint_dir>/checkpoint.pt``.

    The file holds a dict with a single key, ``'model_state_dict'``,
    mapped to the model's ``state_dict()``. If the model is wrapped in
    ``DataParallel`` or ``DistributedDataParallel``, the wrapped module's
    state dict is saved instead, so the stored keys carry no ``module.``
    prefix.

    Args:
        model: The ``torch.nn.Module`` (optionally wrapped) to checkpoint.
        checkpoint_dir: Directory to write into; created if missing.

    Raises:
        OSError: If the directory cannot be created or the file cannot
            be written.
    """
    # Create the target directory up front so the first checkpoint of a
    # run does not fail on a missing path. An empty string means "current
    # directory" for os.path.join, and os.makedirs('') would raise.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pt')
    self._logger.info("Checkpoint model to {}".format(checkpoint_path))

    # Both parallel wrappers expose the real network as `.module`; save
    # that one so the checkpoint loads into an unwrapped model as well.
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    if isinstance(model, parallel_wrappers):
        model = model.module

    torch.save({
        'model_state_dict': model.state_dict(),
    }, checkpoint_path)