def snapshot()

in src/bert_train.py [0:0]


    def snapshot(self, model, model_dir, prefix="best_snaphsot"):
        """Serialize ``model`` to ``<model_dir>/<prefix>model.pt`` via ``torch.save``.

        Data-parallel wrappers (``torch.nn.DataParallel`` and
        ``torch.nn.parallel.DistributedDataParallel``) are unwrapped first, so
        the saved object is the underlying module rather than the wrapper.

        Args:
            model: The module to save, optionally wrapped for data parallelism.
            model_dir: Directory to write into; created if it does not exist.
            prefix: String prepended verbatim to ``'model.pt'`` to form the
                file name (no separator is inserted).

        Returns:
            The path the model was written to.
        """
        # NOTE(review): the default prefix is misspelled ("snaphsot") and no
        # separator is inserted before "model.pt", giving e.g.
        # "best_snaphsotmodel.pt". Both are kept because code that reloads the
        # snapshot may depend on this exact filename -- confirm before renaming.
        snapshot_path = os.path.join(model_dir, prefix) + 'model.pt'

        self._logger.info("Snapshot model to {}".format(snapshot_path))

        # torch.save does not create missing parent directories. An empty
        # model_dir means "current directory", which needs no creation.
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

        # DataParallel and DistributedDataParallel both keep the wrapped model
        # in `.module`; save that so the snapshot does not carry the wrapper.
        if isinstance(model, (torch.nn.DataParallel,
                              torch.nn.parallel.DistributedDataParallel)):
            model = model.module

        # Pickles the whole module object (not only its state_dict).
        torch.save(model, snapshot_path)
        return snapshot_path