def create_checkpoint(self, model, checkpoint_dir)

in src/bert_train.py [0:0]


    def create_checkpoint(self, model, checkpoint_dir):
        """Serialize *model*'s weights to ``checkpoint_dir/checkpoint.pt``.

        Args:
            model: Model to checkpoint. A ``torch.nn.DataParallel`` wrapper
                is unwrapped first, so the saved state-dict keys carry no
                ``module.`` prefix and load into a plain model.
            checkpoint_dir: Target directory; created if it does not exist.
        """
        # Create the directory up front so torch.save does not raise
        # FileNotFoundError on a fresh run.
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pt')

        # Lazy %-args: the message is only formatted if INFO is enabled.
        self._logger.info("Checkpoint model to %s", checkpoint_path)

        # If nn.DataParallel, get the underlying module
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        torch.save({
            'model_state_dict': model.state_dict(),
        }, checkpoint_path)