in captioning/modules/loss_wrapper_joint.py [0:0]
def forward(self, fc_feats, att_feats, trace_feats, box_feats, labels, masks, att_masks, trace_masks, gts, gt_indices,
            sc_flag, struc_flag):
    """Compute the training loss for the mode selected by the two flags.

    Modes, checked in this order:

    * ``struc_flag`` truthy: mix of the language-model loss (``self.crit``)
      and the structure loss (``self.struc_crit``), weighted by
      ``opt.structure_loss_weight``.
    * ``sc_flag`` falsy: supervised loss chosen by ``opt.task``:
      ``'pred_both'`` (caption + trace from one joint model call),
      ``'caption'``, ``'trace'``, or ``'c_joint_t'`` (trace + 0.3 * caption
      + 0.1 * cycle-trace).
    * otherwise: self-critical loss (``self.rl_crit``) with a reward from
      ``get_self_critical_reward`` comparing sampled and greedy outputs.

    Args:
        fc_feats, att_feats, box_feats, att_masks: forwarded to
            ``self.model`` unchanged.
        trace_feats: indexed as ``[:, :, :4]`` (regression targets) and
            ``[:, :, 4]`` (flag channel; value 1 marks a position that is
            excluded from the trace loss).
        labels: token ids; ``labels[..., :-1]`` is the model input and
            ``labels[..., 1:]`` the target.
        masks: ``masks[..., 1:]`` is passed to the caption criterion.
        trace_masks: non-zero entries mark positions that may contribute
            to the trace loss.
        gts: ground truths, re-indexed by ``gt_indices`` before scoring
            (used only by the structure and self-critical modes).
        gt_indices: tensor of indices into ``gts``.
        sc_flag: selects the self-critical mode when ``struc_flag`` is falsy.
        struc_flag: selects the structure-loss mode.

    Returns:
        dict with key ``'loss'`` in every mode; additionally ``'lm_loss'``,
        ``'struc_loss'`` and ``'reward'`` in structure mode and ``'reward'``
        in self-critical mode.

    Raises:
        ValueError: in supervised mode when ``opt.task`` is not one of the
            four supported tasks.
    """
    opt = self.opt

    def masked_trace_l1(pred_trace):
        # Mean absolute error over the four regressed trace channels.
        # A position counts only if its trace mask is non-zero and the fifth
        # trace channel is not 1: words without labels are encoded as
        # [0, 0, 1, 1, 1] and must not contribute to the loss.
        loss_mask = ((trace_masks != 0) * (trace_feats[:, :, 4] != 1)).unsqueeze(2)
        abs_err = torch.abs(pred_trace[:, :, :4] - trace_feats[:, :, :4])
        return (abs_err * loss_mask).sum() / (loss_mask.sum() * 4)

    out = {}
    if struc_flag:
        # Skip whichever term has zero weight so its forward pass is not paid for.
        if opt.structure_loss_weight < 1:
            lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks),
                                labels[..., 1:], masks[..., 1:])
        else:
            lm_loss = torch.tensor(0).type_as(fc_feats)
        if opt.structure_loss_weight > 0:
            output_logsoftmax = (opt.struc_use_logsoftmax
                                 or opt.structure_loss_type == 'softmax_margin'
                                 or 'margin' not in opt.structure_loss_type)
            gen_result, sample_logprobs = self.model(
                fc_feats, att_feats, att_masks,
                opt={'sample_method': opt.train_sample_method,
                     'beam_size': opt.train_beam_size,
                     'output_logsoftmax': output_logsoftmax,
                     'sample_n': opt.train_sample_n},
                mode='sample')
            gts = [gts[idx] for idx in gt_indices.tolist()]
            struc_loss = self.struc_crit(sample_logprobs, gen_result, gts)
        else:
            struc_loss = {'loss': torch.tensor(0).type_as(fc_feats),
                          'reward': torch.tensor(0).type_as(fc_feats)}
        loss = (1 - opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss']
        out['lm_loss'] = lm_loss
        out['struc_loss'] = struc_loss['loss']
        out['reward'] = struc_loss['reward']
    elif not sc_flag:
        task = opt.task
        if task not in ('pred_both', 'caption', 'trace', 'c_joint_t'):
            # Fail before any model call instead of hitting an unbound `loss` below.
            raise ValueError('unsupported task: %r' % (task,))

        if task == 'pred_both':
            # Train generating both caption and trace (baseline training).
            caption_outputs_both, trace_outputs_both = self.model(
                fc_feats, att_feats, trace_feats, box_feats, labels[..., :-1],
                att_masks, trace_masks, task='both')
            loss_both_trace = masked_trace_l1(trace_outputs_both)
            loss_both_caption = self.crit_caption(caption_outputs_both, labels[..., 1:], masks[..., 1:])
            loss = loss_both_caption + loss_both_trace
        else:
            if task in ('caption', 'c_joint_t'):
                # Caption generation.
                caption_outputs = self.model(fc_feats, att_feats, trace_feats, box_feats, labels[..., :-1],
                                             att_masks, trace_masks, task='caption')
                loss_caption = self.crit_caption(caption_outputs, labels[..., 1:], masks[..., 1:])
            if task in ('trace', 'c_joint_t'):
                # Trace generation, trained as a regression.
                trace_outputs = self.model(fc_feats, att_feats, trace_feats, box_feats, labels[..., :-1],
                                           att_masks, trace_masks, task='trace')
                loss_trace = masked_trace_l1(trace_outputs)

            # NOTE: other cycle-consistency variants (feeding predicted traces
            # back into the caption task; shuffling trace segments within a
            # sample) existed here only as disabled, commented-out code.
            if task == 'c_joint_t':
                # Cycle loss: caption conditioned on traces exchanged at random
                # within the batch, then regress the true trace from it.
                random_idx = np.arange(trace_feats.shape[0])
                np.random.shuffle(random_idx)
                rnd_trace_feats = trace_feats[random_idx]
                rnd_caption_outputs = self.model(fc_feats, att_feats, rnd_trace_feats, box_feats,
                                                 labels[..., :-1], att_masks, trace_masks, task='caption')
                # exp() undoes the log applied to the caption outputs.
                caption_outputs_cycle_1 = torch.exp(rnd_caption_outputs)
                trace_outputs_cycle_1 = self.model(fc_feats, att_feats, trace_feats, box_feats,
                                                   caption_outputs_cycle_1, att_masks, trace_masks,
                                                   task='cycle_trace')
                loss_cycle_trace = masked_trace_l1(trace_outputs_cycle_1)
                loss = loss_trace + 0.3 * loss_caption + 0.1 * loss_cycle_trace
            elif task == 'caption':
                loss = loss_caption
            else:
                # task == 'trace'. The original compared against 'caption' a
                # second time, leaving this branch unreachable and `loss` unbound.
                loss = loss_trace
    else:
        # Self-critical training: the greedy decode is the reward baseline,
        # produced in eval mode and without gradients.
        self.model.eval()
        with torch.no_grad():
            greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
                                       mode='sample',
                                       opt={'sample_method': opt.sc_sample_method,
                                            'beam_size': opt.sc_beam_size})
        self.model.train()
        gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
                                                 opt={'sample_method': opt.train_sample_method,
                                                      'beam_size': opt.train_beam_size,
                                                      'sample_n': opt.train_sample_n},
                                                 mode='sample')
        gts = [gts[idx] for idx in gt_indices.tolist()]
        reward = get_self_critical_reward(greedy_res, gts, gen_result, self.opt)
        # Match the dtype and device of the sampled log-probabilities.
        reward = torch.from_numpy(reward).to(sample_logprobs)
        loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
        out['reward'] = reward[:, 0].mean()
    out['loss'] = loss
    return out