# run_ranking.py
# NOTE: the imports below are reconstructed so the script is self-contained;
# the commented project-specific imports are assumptions -- point them at
# wherever this repo defines MODEL_CLASSES, generate_dataloader, Reranker,
# and RerankerInference.
import argparse
import glob
import logging
import os

import torch
import pytorch_lightning as pl
from pytorch_lightning import loggers
from pytorch_lightning.callbacks import ModelCheckpoint

# from model_utils import MODEL_CLASSES, Reranker, RerankerInference  # assumed path
# from data_utils import generate_dataloader  # assumed path

logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()
parser.add_argument('--gpu', nargs='*')
parser.add_argument('--pretrained_model', default='bert-base-uncased', type=str)
parser.add_argument('--overwrite', default=None, nargs='*')
parser.add_argument('--do_train', action='store_true')
parser.add_argument('--do_test', action='store_true')
parser.add_argument('--data_dir', default='/home/ec2-user/efs/ott-qa/', type=str)
parser.add_argument('--num_cand', default=64, type=int)
parser.add_argument('--question_type', default='ott-qa', choices=['opensquad', 'wikisql_denotation', 'NQ-open', 'ott-qa'])
parser.add_argument('--task_type', default='AutoModelForSequenceClassification', type=str)
parser.add_argument('--checkpoint_dir', default='/home/ec2-user/efs/ck/ott-qa/', type=str)
parser.add_argument('--cache_dir', default='/home/ec2-user/efs/cache/', type=str)
parser.add_argument('--tensorboard_dir', default='/home/ec2-user/efs/wandb/', type=str)
    parser.add_argument('--load_model_checkpoint', default=None, type=str, help="The checkpoint file from which to continue training.")
parser.add_argument('--weight_decay', default=0.01, type=float)
parser.add_argument('--learning_rate', default=5e-5, type=float)
parser.add_argument('--batch_size', default=16, type=int)
parser.add_argument('--test_batch_size', default=1024, type=int)
parser.add_argument('--adam_epsilon', default=1e-8, type=float)
parser.add_argument('--lr_schedule', default='linear', type=str, choices=['linear', 'cosine', 'cosine_hard', 'constant'])
    parser.add_argument('--warmup_steps', default=0.1, type=float, help="a value < 1 is a fraction of total steps; otherwise an absolute number of steps")
parser.add_argument('--gradient_accumulation_steps', default=2, type=int)
parser.add_argument('--num_train_epochs', default=3, type=int)
parser.add_argument('--num_train_steps', default=10000, type=int)
parser.add_argument('--max_length', default=150, type=int)
    # expose all built-in pl.Trainer options (e.g. --precision, used below) on the parser
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.question_type)
logger.info(f"Checkpoint directory: {args.checkpoint_dir}")
    if args.gpu is None:
        logger.info("not using GPU")
        args.gpu = 0
    else:
        try:
            args.gpu = [int(x) for x in args.gpu]
            logger.info(f"using gpu {args.gpu}")
        except ValueError as e:
            raise ValueError("--gpu only accepts numeric device ids") from e
    # load the pretrained config, tokenizer and model for the chosen task type
logger.info("loading pretrained model and tokenizer")
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.task_type]
config = config_class.from_pretrained(args.pretrained_model, cache_dir=args.cache_dir)
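    # one output logit per (question, candidate) pair serves as the ranking score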
config.num_labels = 1
tokenizer = tokenizer_class.from_pretrained(args.pretrained_model, use_fast=True, cache_dir=args.cache_dir)
model = model_class.from_pretrained(
args.pretrained_model,
from_tf=False,
config=config,
cache_dir=args.cache_dir)
    # register '[title]' as a special token so the tokenizer never splits it
    additional_special_tokens_dict = {'additional_special_tokens': ['[title]']}
    tokenizer.add_special_tokens(additional_special_tokens_dict)
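    # grow the embedding matrix to cover the newly added token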
model.resize_token_embeddings(len(tokenizer))
if args.overwrite is None:
args.overwrite = []
# checkpoint
checkpoint_dir = os.path.join(args.checkpoint_dir, f'{args.pretrained_model}/')
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        monitor='avg_val_performance',
        filepath=checkpoint_dir + '{epoch}-{val_loss:.4f}-{avg_val_performance:.4f}',
        mode='max')
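    # NOTE: this filename pattern is parsed back in the --do_test branch to
    # recover val_loss / avg_val_performance when picking the best checkpoint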
# training and testing
if args.do_train:
        # initialize the dataloaders; listing a split in --overwrite rebuilds
        # that split's cached features (assumed generate_dataloader behavior)
train_dataloader = generate_dataloader(
args.data_dir,
tokenizer,
args.max_length,
'train',
args.num_cand,
'train' in args.overwrite,
args.batch_size,
args.question_type,
)
val_dataloader = generate_dataloader(
args.data_dir,
tokenizer,
args.max_length,
'dev',
args.num_cand,
'dev' in args.overwrite,
args.batch_size,
args.question_type,
)
        if args.num_train_steps <= 0:
            # derive the step budget from the data when not set explicitly
            args.num_train_steps = len(train_dataloader) * args.num_train_epochs
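        # Reranker is assumed to be a pl.LightningModule that wraps the HF
        # model with the training/validation loops used by trainer.fit below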
bert_ranker = Reranker(model, tokenizer, args)
        if args.load_model_checkpoint is not None:
            logger.info(f"Loading the checkpoint {args.load_model_checkpoint} and continuing training")
            model_checkpoint = torch.load(args.load_model_checkpoint, map_location=lambda storage, loc: storage)
            model_dict = model_checkpoint['state_dict']
            bert_ranker.load_state_dict(model_dict)
        # despite the flag name, --tensorboard_dir is used as the W&B save dir
        wandb_logger = loggers.WandbLogger(save_dir=args.tensorboard_dir, project='hybridQA-ott-qa')
        trainer = pl.Trainer(logger=wandb_logger,
                             checkpoint_callback=checkpoint_callback,
                             gpus=args.gpu,
                             distributed_backend='dp',
                             val_check_interval=0.25,  # run validation four times per epoch
                             max_epochs=args.num_train_epochs,
                             max_steps=args.num_train_steps,
                             accumulate_grad_batches=args.gradient_accumulation_steps,
                             gradient_clip_val=1.0,
                             precision=args.precision)
trainer.fit(bert_ranker, train_dataloader, val_dataloader)
if args.do_test:
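        # release cached GPU memory from training before the large test batches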
torch.cuda.empty_cache()
        # initialize the test dataloader
test_dataloader = generate_dataloader(
args.data_dir,
tokenizer,
args.max_length,
'test',
args.num_cand,
'test' in args.overwrite,
args.test_batch_size,
args.question_type,
)
        if args.load_model_checkpoint:
            best_checkpoint_file = args.load_model_checkpoint
        else:
            # pick the checkpoint with the highest avg_val_performance by parsing
            # the '{epoch}-{val_loss}-{avg_val_performance}.ckpt' filenames
            best_checkpoint_file = None
            best_val_performance = -100.
            best_val_loss = 100.
            for checkpoint_file in glob.glob(checkpoint_dir + "*avg_val_performance*.ckpt"):
                val_performance = float(checkpoint_file.split('=')[-1].replace('.ckpt', ''))
                val_loss = float(checkpoint_file.split('=')[-2].split('-')[0])
                if val_performance > best_val_performance:
                    best_val_performance = val_performance
                    best_val_loss = val_loss
                    best_checkpoint_file = checkpoint_file
            if best_checkpoint_file is None:
                raise FileNotFoundError(f"no checkpoint matching the pattern found under {checkpoint_dir}")
        logger.info(f"Loading the checkpoint: {best_checkpoint_file}")
        # load the best weights into the inference module
bert_ranker = RerankerInference(model, tokenizer, args)
best_checkpoint = torch.load(best_checkpoint_file, map_location=lambda storage, loc: storage)
bert_ranker.load_state_dict(best_checkpoint['state_dict'])
# test using Trainer test function
        trainer = pl.Trainer(gpus=args.gpu, distributed_backend='dp', benchmark=True)  # benchmark=True: cuDNN autotunes kernels for fixed-size batches
trainer.test(bert_ranker, test_dataloader)
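

# Example invocations (paths and flags are illustrative):
#   python run_ranking.py --do_train --gpu 0 1
#   python run_ranking.py --do_test --gpu 0 --load_model_checkpoint <ckpt>
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    main()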