in scripts/reader/train.py [0:0]
def _make_loader(args, dataset, batch_size, shuffle):
    """Wrap a reader dataset in a DataLoader using the configured sampler.

    If ``args.sort_by_len`` is set, examples are drawn through
    ``data.SortedBatchSampler`` built from ``dataset.lengths()``, with
    ``shuffle`` forwarded to it (sorting by length makes training faster).
    Otherwise a torch ``RandomSampler`` (``shuffle=True``) or
    ``SequentialSampler`` (``shuffle=False``) is used.

    Args:
        args: parsed training arguments; reads ``sort_by_len``,
            ``data_workers`` and ``cuda``.
        dataset: the ``data.ReaderDataset`` to iterate over.
        batch_size: number of examples per batch.
        shuffle: whether example order should be randomized.

    Returns:
        A ``torch.utils.data.DataLoader`` that collates batches with
        ``vector.batchify``.
    """
    if args.sort_by_len:
        sampler = data.SortedBatchSampler(dataset.lengths(),
                                          batch_size,
                                          shuffle=shuffle)
    elif shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=args.data_workers,
        collate_fn=vector.batchify,
        # Pin host memory only when tensors will be copied to the GPU.
        pin_memory=args.cuda,
    )


def main(args):
    """Train a DocReader model, saving it whenever the dev metric improves.

    Steps:
      1. Load train/dev examples and, when ``args.official_eval`` is set,
         the raw dev texts, per-example offsets and gold answers.
      2. Resume from ``args.model_file + '.checkpoint'`` if
         ``args.checkpoint`` is set and that file exists. Otherwise load
         ``args.pretrained`` (optionally expanding its dictionary) or build
         a new model, then optionally set up partial embedding tuning and
         create the optimizer.
      3. Move the model to the GPU and/or parallelize it if requested.
      4. Build the train and dev data loaders.
      5. For each epoch: train, run unofficial validation on train and dev,
         optionally run official validation, and save the model to
         ``args.model_file`` whenever ``args.valid_metric`` improves.

    Args:
        args: parsed training arguments (must support ``vars()``, e.g. an
            argparse Namespace).
    """
    # --------------------------------------------------------------------------
    # DATA
    logger.info('-' * 100)
    logger.info('Load data files')
    train_exs = utils.load_data(args, args.train_file, skip_no_answer=True)
    logger.info('Num train examples = %d', len(train_exs))
    dev_exs = utils.load_data(args, args.dev_file)
    logger.info('Num dev examples = %d', len(dev_exs))

    # If we are doing official evals then we need to:
    # 1) Load the original text to retrieve spans from offsets.
    # 2) Load the (multiple) text answers for each question.
    if args.official_eval:
        dev_texts = utils.load_text(args.dev_json)
        dev_offsets = {ex['id']: ex['offsets'] for ex in dev_exs}
        dev_answers = utils.load_answers(args.dev_json)

    # --------------------------------------------------------------------------
    # MODEL
    logger.info('-' * 100)
    start_epoch = 0
    if args.checkpoint and os.path.isfile(args.model_file + '.checkpoint'):
        # Just resume training, no modifications.
        logger.info('Found a checkpoint...')
        checkpoint_file = args.model_file + '.checkpoint'
        model, start_epoch = DocReader.load_checkpoint(checkpoint_file, args)
    else:
        # Training starts fresh. But the model state is either pretrained or
        # newly (randomly) initialized.
        if args.pretrained:
            logger.info('Using pretrained model...')
            model = DocReader.load(args.pretrained, args)
            if args.expand_dictionary:
                logger.info('Expanding dictionary for new data...')
                # Add words in training + dev examples
                words = utils.load_words(args, train_exs + dev_exs)
                added = model.expand_dictionary(words)
                # Load pretrained embeddings for added words
                if args.embedding_file:
                    model.load_embeddings(added, args.embedding_file)
        else:
            logger.info('Training model from scratch...')
            model = init_from_scratch(args, train_exs, dev_exs)

        # Set up partial tuning of embeddings
        if args.tune_partial > 0:
            logger.info('-' * 100)
            logger.info('Counting %d most frequent question words',
                        args.tune_partial)
            top_words = utils.top_question_words(
                args, train_exs, model.word_dict
            )
            # Sanity-check preview: the first five and the last five entries.
            for word in top_words[:5]:
                logger.info(word)
            logger.info('...')
            for word in top_words[-5:]:
                logger.info(word)
            # Entries are indexed with [0] to get the word; presumably
            # (word, count) pairs -- confirm against utils.top_question_words.
            model.tune_embeddings([w[0] for w in top_words])

        # Set up optimizer (only for fresh training; the resume branch above
        # does not call it).
        model.init_optimizer()

    # Use the GPU?
    if args.cuda:
        model.cuda()

    # Use multiple GPUs?
    if args.parallel:
        model.parallelize()

    # --------------------------------------------------------------------------
    # DATA ITERATORS
    # Two datasets: train and dev. If we sort by length it's faster.
    logger.info('-' * 100)
    logger.info('Make data loaders')
    # Train: single_answer=True, shuffled order.
    train_dataset = data.ReaderDataset(train_exs, model, single_answer=True)
    train_loader = _make_loader(args, train_dataset, args.batch_size,
                                shuffle=True)
    # Dev: single_answer=False, fixed order.
    dev_dataset = data.ReaderDataset(dev_exs, model, single_answer=False)
    dev_loader = _make_loader(args, dev_dataset, args.test_batch_size,
                              shuffle=False)

    # --------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info('-' * 100)
    logger.info('CONFIG:\n%s',
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info('-' * 100)
    logger.info('Starting training...')
    # NOTE(review): best_valid always starts at 0, including when resuming
    # from a checkpoint (only the model and start epoch are recovered above).
    # The first validated epoch after a resume therefore overwrites
    # args.model_file whenever its metric is above 0, even if it is worse
    # than the best score reached before the interruption.
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    for epoch in range(start_epoch, args.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_loader, model, stats)

        # Validate unofficial (train)
        validate_unofficial(args, train_loader, model, stats, mode='train')

        # Validate unofficial (dev)
        result = validate_unofficial(args, dev_loader, model, stats,
                                     mode='dev')

        # Validate official; when enabled, its result replaces the unofficial
        # dev result for model selection below.
        if args.official_eval:
            result = validate_official(args, dev_loader, model, stats,
                                       dev_offsets, dev_texts, dev_answers)

        # Save best valid
        if result[args.valid_metric] > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)',
                        args.valid_metric, result[args.valid_metric],
                        stats['epoch'], model.updates)
            model.save(args.model_file)
            stats['best_valid'] = result[args.valid_metric]