in EAIEvaluation/HiTUT/hitut_train/custom_train.py [0:0]
def train_one_epoch(model, optimizer, data_loader, counts, loss_weights=None):
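    """Train `model` for one epoch over `data_loader`.

    `counts` is a running-state dict used here: 'epoch', 'start_time',
    per-task 'iter_<task_type>' / 'dlen_<task_type>' counters, and optionally a
    'scheduler' that is stepped after every optimizer update. `loss_weights`
    (only read when args.weigh_loss is set) maps each task type to a tensor of
    per-loss weighting parameters, treated below as log-variances.
    """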
model.train()
count = 0
logging.info('Training [Level: %s]' % (model.args.train_level))
epoch_start_time = time.time()
logging.info('************ epoch: %d ************' %(counts['epoch']))
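    # each batch carries a single task type in batch['batch_type'], which selects the active losses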
for batch in tqdm(data_loader):
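        # debug leftover: uncomment to stop early after 30 batches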
# count += 1
# if count == 30:
# break
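        # with BERT enabled, randomly skip about half the batches (presumably to cut fine-tuning cost)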
if model.args.use_bert and random.random() > 0.5:
continue
task_type = batch['batch_type'][0]
enable_mask = 'mani' in task_type
enable_navi_aux = 'navi' in task_type and model.args.auxiliary_loss_navi
type_loss, arg_loss, mask_loss, navi_losses = model(batch)
        navi_loss = sum(navi_losses.values()) if enable_navi_aux else 0
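        # optional uncertainty-style weighting: each stored weight s is treated as a
        # log-variance, so its loss term is scaled by exp(-s) and 0.5*s is added on top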
if model.args.weigh_loss:
weights = loss_weights[task_type]
w = [torch.exp(-i) for i in weights]
            loss = w[0] * type_loss + w[1] * arg_loss + w[2] * mask_loss + w[3] * navi_loss + 0.5 * weights.sum()
record_loss = (type_loss + arg_loss + mask_loss + navi_loss).item()
else:
loss = type_loss + arg_loss + mask_loss + navi_loss
record_loss = loss.item()
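        # per-task iteration counters persist across epochs; log roughly 30 times per epoch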
iter_num = counts['iter_%s'%task_type]
local_iter_num = counts['iter_%s'%task_type] % counts['dlen_%s'%task_type]
        if local_iter_num % max(counts['dlen_%s'%task_type]//30, 1) == 0:
lr = optimizer.param_groups[0]["lr"]
mask_str = 'mask: %.4f |'%mask_loss.item() if enable_mask else ''
navi_str = 'vis: %.4f |rea: %.4f |prog: %.4f |'%(navi_losses['visible'].item(),
navi_losses['reached'].item(), navi_losses['progress'].item()) if enable_navi_aux else ''
logging.info('[%8s iter%4d] loss total: %.4f |type: %.4f |arg: %.4f |%s%slr: %.1e'%(
task_type, local_iter_num, record_loss, type_loss.item(), arg_loss.item(), mask_str, navi_str, lr))
writer.add_scalar('train_loss/%s/total'%task_type, record_loss, iter_num)
writer.add_scalar('train_loss/%s/type'%task_type, type_loss.item(), iter_num)
writer.add_scalar('train_loss/%s/arg'%task_type, arg_loss.item(), iter_num)
if enable_mask:
writer.add_scalar('train_loss/low/mask', mask_loss.item(), iter_num)
if enable_navi_aux:
for k, v in navi_losses.items():
writer.add_scalar('train_loss/low/%s'%k, v.item(), iter_num)
if model.args.weigh_loss:
writer.add_scalars('weights/%s'%task_type, {'type': 1/w[0], 'arg': 1/w[1], 'mask': 1/w[2]}, iter_num)
                if local_iter_num % (max(counts['dlen_%s'%task_type]//30, 1) * 6) == 0:
                    w_str = '(' + ' '.join(['%.3f' % i.item() for i in w]) + ')'
                    logging.info('%s loss weights 1/var: (type, arg, mask, navi)=%s' % (task_type, w_str))
writer.flush()
counts['iter_%s'%task_type] += 1
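        # standard backprop step; also step the optional per-batch LR scheduler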
optimizer.zero_grad()
loss.backward()
optimizer.step()
if 'scheduler' in counts:
counts['scheduler'].step()
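    # epoch and cumulative wall-clock time summary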
et = '%.1fm'%((time.time() - epoch_start_time)/60)
tt = time.time() - counts['start_time']
tt = '%dh%dm'%(tt//3600, tt//60%60)
logging.info('[%s] epoch %d finished (epoch time: %s | total time: %s)' %(model.args.train_level,
counts['epoch'], et, tt))