in src/bert_train.py [0:0]
def snapshot(self, model, model_dir, prefix="best_snaphsot"):
    """Save the whole model object to ``<model_dir>/<prefix>model.pt``.

    If the model is wrapped in ``torch.nn.DataParallel`` or
    ``torch.nn.parallel.DistributedDataParallel``, the wrapped module is
    saved instead. The checkpoint then does not depend on the parallel
    wrapper (or on a process group) at load time.

    Args:
        model: The ``torch.nn.Module`` to save, optionally parallel-wrapped.
        model_dir: Directory to write into; created if it does not exist.
            An empty string means the current working directory.
        prefix: Filename prefix. It is concatenated directly with
            ``'model.pt'`` (no separator), and the default keeps its
            historical spelling. Existing loaders may rely on the resulting
            file name, so neither is changed here.

    Returns:
        The path of the written checkpoint file.
    """
    snapshot_path = os.path.join(model_dir, prefix) + 'model.pt'
    self._logger.info("Snapshot model to %s", snapshot_path)

    # torch.save does not create missing directories; os.makedirs("") would
    # raise, so only create when a directory was actually given.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    # Unwrap (possibly nested) parallel wrappers so the plain module is saved.
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    while isinstance(model, parallel_wrappers):
        model = model.module

    # NOTE: this pickles the full module object. Loading it requires
    # unpickling (torch.load(..., weights_only=False)), so only load
    # checkpoints from trusted sources.
    torch.save(model, snapshot_path)
    return snapshot_path