# def train()
#
# in modeling/main.py [0:0]


def train(args, tokenizer, model):
	"""Train `model` with teacher forcing, periodically evaluating on dev.

	Runs up to `args.max_epoch` epochs over the teacher-forced training
	dataloader, accumulating gradients every
	`args.gradient_accumulation_steps` batches. After the first epoch,
	evaluates on the dev set every `eval_step` batches (teacher-forced
	loss plus free-running generation), keeps the best checkpoint via
	`save_best_model`, and early-stops after `args.no_improve_max`
	consecutive evaluations without improvement.

	Args:
		args: hyperparameter namespace (learning_rate, max_epoch,
			gradient_accumulation_steps, eval_interval, ...).
		tokenizer: tokenizer forwarded to the dataloaders / trainer tuple.
		model: model to optimize in place.

	Returns:
		None. Side effects: updates `model` parameters and saves
		checkpoints through `save_best_model`.
	"""
	# set dataloaders: teacher-forced train/dev, plus a generation-mode dev set
	train_dataloader = set_dataloader(args, tokenizer, 'train', 'teacher_force', data_size=args.train_size)
	dev_dataloader = set_dataloader(args, tokenizer, 'dev', 'teacher_force')
	dev_gen_dataloader = set_dataloader(args, tokenizer, 'dev', 'generation')

	# set optimizer and (optional) linear warmup/decay lr scheduler
	optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
	if args.use_scheduler:
		t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.max_epoch
		scheduler = get_linear_schedule_with_warmup(
			optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
		)
	else:
		scheduler = None
	trainer = (model, optimizer, scheduler, tokenizer)

	print('Test before training!')
	model.eval()
	with torch.no_grad():
		_ = run_one_epoch('dev', dev_dataloader, trainer, -1, 'teacher_force')

	print('Start training!\n{}'.format('***'*30))
	# FIX: clamp to >= 1 — if eval_interval < train_batch_size the integer
	# division yields 0 and the `global_step % eval_step` below raises
	# ZeroDivisionError.
	eval_step = max(1, args.eval_interval // args.train_batch_size)

	# score of query rewrite, coreference resolution, and joint learning (average of two)
	best_score = {'best-QR': -10000, 'best-COREF': -10000, 'best-JOINT': -10000}
	global_step = 0
	no_improve_count = 0
	for epoch in range(args.max_epoch):
		t0 = time.time()
		model.train()
		model.zero_grad()
		# running sums of each loss component for this epoch
		LOSS = {'bi': 0, 'lm': 0, 'mention': 0, 'reference': 0, 'total': 0}
		iterator = enumerate(tqdm(train_dataloader, desc="Epoch {}".format(epoch), disable=args.disable_display))

		if args.disable_display:
			print('Training progress is not showing')

		local_step = -1  # FIX: avoid NameError in the averaging below if the dataloader is empty
		for local_step, batch in iterator:
			loss, _, _, _, _, _, _ = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], \
											token_type_ids=batch['token_type_ids'], labels=batch['label_ids'], step=None, \
											mention_labels=batch['mention_label_ids'], batch=batch, coref_links=batch['coref_label'])

			for k, v in loss.items():
				LOSS[k] += v.item()
			global_step += 1

			# backprop (scaled for gradient accumulation); skip no-op zero losses
			if loss['total'].item() != 0:
				loss['total'] = loss['total'] / args.gradient_accumulation_steps
				loss['total'].backward()

			# apply accumulated gradients every `gradient_accumulation_steps` batches
			if global_step % args.gradient_accumulation_steps == 0:
				torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
				optimizer.step()
				if args.use_scheduler:
					scheduler.step()
				optimizer.zero_grad()

			# evaluate model on dev (loss + generation metrics) after the first epoch
			if global_step % eval_step == 0 and epoch > 0:
				model.eval()
				with torch.no_grad():
					_ = run_one_epoch('dev', dev_dataloader, trainer, epoch, 'teacher_force') # get dev loss
					res = run_one_epoch('dev', dev_gen_dataloader, trainer, epoch, 'generation') # get dev result
				model.train()

				# save model when the dev score improves; otherwise grow patience counter
				if save_best_model(args.task, res, best_score):
					no_improve_count = 0
				else:
					no_improve_count += 1

				# early stop once patience is exhausted
				if no_improve_count == args.no_improve_max:
					print('Early stop!')
					return

		# get average train loss for the epoch (skip if no batches were seen)
		if local_step >= 0:
			for k in LOSS:
				LOSS[k] /= (local_step+1)
			print_loss(epoch, 'train', LOSS, t0)

		print('***'*30)

	print('Reach max epoch: {}!'.format(args.max_epoch))