def run_one_epoch()

in modeling/main.py [0:0]


def run_one_epoch(data_type, dataloader, trainer, epoch, run_type, collector=None):
	"""Run one evaluation epoch: teacher-forced loss computation or free-running generation scoring."""
	t0 = time.time()
	assert data_type in ['dev', 'test']
	assert run_type in ['teacher_force', 'generation']
	model, optimizer, scheduler, tokenizer = trainer

	# result containers: summed loss components, exact-match flags, binary-match flags
	LOSS, match, bi_match = {'bi': 0, 'lm': 0, 'mention': 0, 'reference': 0, 'total': 0}, [], []
	coref_lines = []
	iterator = enumerate(tqdm(dataloader, desc="Epoch {} {}".format(epoch, run_type), disable=args.disable_display))

	if args.disable_display:
		print('Evaluation progress bar is disabled')

	for step, batch in iterator:
		if run_type == 'teacher_force':
			loss, _, _, _, _, _, _ = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'],
											token_type_ids=batch['token_type_ids'], labels=batch['label_ids'],
											mention_labels=batch['mention_label_ids'], batch=batch, coref_links=batch['coref_label'])
			for k, v in loss.items():
				LOSS[k] += v.item()

		else:
			decode_output = decode(args, batch, model, tokenizer)
			score_fn(args, decode_output, batch, match, collector, qr_metric, coref_lines, bi_match)

	# log
	if run_type == 'teacher_force':
		for k in LOSS:  # average each loss component over the number of batches
			LOSS[k] /= (step + 1)
		print_loss(epoch, data_type, LOSS, t0)
		return LOSS
	else: # record decoding result
		res = {}
		if 'qr' in args.task:
			qr_res = qr_metric.get_metric(reset=True)
			qr_res['Exact match'] = sum(match) / len(match) * 100
			get_binary_res(bi_match, qr_res, args)
			res['qr'] = qr_res
		else:
			res['qr'] = {}

		if 'coref' in args.task:
			# prepare conll files
			key_path = args.dev_conll if data_type == 'dev' else args.test_conll
			response_path = 'temp/{}.response'.format(args.model_name) # a temp file for calculating coref score
			with open(response_path, 'w') as f:
				f.writelines(coref_lines)
			res['coref'] = coref_evaluate(key_path, response_path, args)
		else:
			res['coref'] = {}

		print_score(args, epoch, data_type, res, t0)
		return res
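
A minimal usage sketch, assuming the surrounding modeling/main.py module already provides the model, optimizer, scheduler, tokenizer and module-level helpers (args, decode, score_fn, qr_metric, coref_evaluate, etc.) referenced above; dev_dataloader and the epoch value below are placeholders, and only the run_one_epoch signature and return values come from the code itself.

# Hypothetical caller: dev_dataloader and epoch are placeholders, not repository code.
trainer = (model, optimizer, scheduler, tokenizer)  # tuple order expected by run_one_epoch

# teacher-forced pass: returns the loss dict averaged over batches
dev_loss = run_one_epoch('dev', dev_dataloader, trainer, epoch=0, run_type='teacher_force')

# generation pass: returns {'qr': ..., 'coref': ...} score dicts
dev_scores = run_one_epoch('dev', dev_dataloader, trainer, epoch=0, run_type='generation', collector=[])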