in modeling/main.py [0:0]
def run_one_epoch(data_type, dataloader, trainer, epoch, run_type, collector=None):
    """Run one evaluation epoch in either teacher-forcing or generation mode.

    Args:
        data_type: which split is being evaluated, 'dev' or 'test'.
        dataloader: iterable of pre-collated batches (dicts with input/label tensors).
        trainer: 4-tuple ``(model, optimizer, scheduler, tokenizer)``; only the
            model and tokenizer are used here.
        epoch: current epoch number (for progress/log display only).
        run_type: 'teacher_force' (compute losses) or 'generation' (decode and score).
        collector: optional sink passed through to ``score_fn`` to record
            decoding results — presumably accumulates per-example output;
            TODO confirm against score_fn.

    Returns:
        In 'teacher_force' mode: dict of averaged losses
        ('bi', 'lm', 'mention', 'reference', 'total').
        In 'generation' mode: dict with 'qr' and 'coref' metric sub-dicts
        (empty dicts for tasks not in ``args.task``).

    Raises:
        ValueError: if ``data_type`` or ``run_type`` is not one of the
            allowed values (was ``assert`` before — asserts are stripped
            under ``python -O``, so validate explicitly).
    """
    t0 = time.time()
    if data_type not in ('dev', 'test'):
        raise ValueError("data_type must be 'dev' or 'test', got {!r}".format(data_type))
    if run_type not in ('teacher_force', 'generation'):
        raise ValueError("run_type must be 'teacher_force' or 'generation', got {!r}".format(run_type))

    model, optimizer, scheduler, tokenizer = trainer
    # Result containers: accumulated losses, exact-match flags, binary-match flags.
    LOSS = {'bi': 0, 'lm': 0, 'mention': 0, 'reference': 0, 'total': 0}
    match, bi_match = [], []
    coref_lines = []  # CoNLL-format lines for the coreference scorer

    iterator = enumerate(tqdm(dataloader, desc="Epoch {} {}".format(epoch, run_type), disable=args.disable_display))
    if args.disable_display:
        print('Evaluation progress is not showing')

    step = -1  # stays -1 if the dataloader is empty, so the averaging guard below works
    for step, batch in iterator:
        if run_type == 'teacher_force':
            # Forward pass with gold labels; only the loss dict is needed here.
            loss, _, _, _, _, _, _ = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], \
                token_type_ids=batch['token_type_ids'], labels=batch['label_ids'], \
                mention_labels=batch['mention_label_ids'], batch=batch, coref_links=batch['coref_label'])
            for k, v in loss.items():
                LOSS[k] += v.item()
        else:
            # Free-running decode, then scoring against references.
            decode_output = decode(args, batch, model, tokenizer)
            score_fn(args, decode_output, batch, match, collector, qr_metric, coref_lines, bi_match)

    # log
    if run_type == 'teacher_force':
        n_steps = step + 1
        if n_steps > 0:  # avoid division by zero on an empty dataloader
            for k in LOSS:  # iterate keys only; values are rebound in place
                LOSS[k] /= n_steps
        print_loss(epoch, data_type, LOSS, t0)
        return LOSS
    else:  # record decoding result
        res = {}
        if 'qr' in args.task:
            qr_res = qr_metric.get_metric(reset=True)
            # Guard against an empty match list (no scored examples).
            qr_res['Exact match'] = (sum(match) / len(match) * 100) if match else 0.0
            get_binary_res(bi_match, qr_res, args)
            res['qr'] = qr_res
        else:
            res['qr'] = {}
        if 'coref' in args.task:
            # prepare conll files
            key_path = args.dev_conll if data_type == 'dev' else args.test_conll
            response_path = 'temp/{}.response'.format(args.model_name)  # a temp file for calculating coref score
            with open(response_path, 'w') as f:
                f.writelines(coref_lines)
            res['coref'] = coref_evaluate(key_path, response_path, args)
        else:
            res['coref'] = {}
        print_score(args, epoch, data_type, res, t0)
        return res