in src/model.py [0:0]
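# imports assumed by this excerpt; label2_k_hots, mask_from_eos,
# predictions_to_idxs and DC are helpers defined elsewhere in the repository
import numpy as np
import torch
import torch.nn as nn

# the method below references a module-level `device`; a common definition
# (an assumption, not shown in the excerpt) is:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
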
def forward(self, img_inputs, label_target=None, maxnumlabels=0,
            keep_cnn_gradients=False, compute_losses=False,
            compute_predictions=False):
    losses = {}
    predictions = None

    # losses require targets, and targets are only given when losses are requested
    assert (label_target is not None and compute_losses) or \
           (label_target is None and not compute_losses)
    if not compute_losses and not compute_predictions:
        return losses, predictions

    # encode image
    img_features = self.image_encoder(img_inputs, keep_cnn_gradients)
    if self.decoder_ff:
        # use the feed-forward decoder to predict the label set and its cardinality
        label_logits, cardinality_logits = self.decoder(img_features)

        if compute_losses:
            # convert the label target to a k-hot vector
            target_k_hot = label2_k_hots(label_target, self.pad_value)
            # the cardinality target is the number of positive labels
            cardinality_target = target_k_hot.sum(dim=-1).unsqueeze(1)
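            # illustration (assuming label2_k_hots zeroes out pad positions):
            # a target row [3, 7, pad, pad] becomes a k-hot vector with ones
            # at indices 3 and 7, so its cardinality target is 2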
            # compute label loss
            losses['label_loss'] = self.crit(label_logits, target_k_hot)

            # compute cardinality loss if needed
            if self.crit_cardinality is not None:
                # shift the target count to match class indices: the 1st
                # cardinality class corresponds to 0 labels only when
                # use_empty_set is true; otherwise it corresponds to 1 label
                offset = 0 if self.use_empty_set else 1
                losses['cardinality_loss'] = self.crit_cardinality(
                    cardinality_logits,
                    (cardinality_target.squeeze() - offset).long())
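                # e.g. with use_empty_set False, a target of 3 labels is
                # mapped to cardinality class index 2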
        if compute_predictions:
            # build the log-cardinality distribution used to pick the set size
            if self.card_type == 'dc' and self.loss_label == 'bce':
                offset = 0 if self.use_empty_set else 1
                cardinality = torch.log(DC(cardinality_logits, dataset=self.dataset))
                u_term = np.array(list(range(cardinality.size(-1)))) + offset
                u_term = u_term * self.u_term
                u_term = torch.from_numpy(u_term).to(device).unsqueeze(0).float()
                cardinality = cardinality + u_term
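                # u_term grows linearly with the candidate set size k, acting
                # as a log-space bias of k * self.u_term that favors (or
                # penalizes) larger cardinalities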
            elif self.card_type == 'cat':
                cardinality = torch.nn.functional.log_softmax(
                    cardinality_logits + self.eps, dim=-1)
            else:
                cardinality = None

            # apply the output nonlinearity to the label logits
            if self.loss_label == 'td':
                label_probs = nn.functional.softmax(label_logits, dim=-1)
            else:
                label_probs = torch.sigmoid(label_logits)

            # get label ids
            predictions = predictions_to_idxs(
                label_probs,
                maxnumlabels,
                self.pad_value,
                th=self.th,
                cardinality_prediction=cardinality,
                which_loss=self.loss_label,
                accumulate_probs=self.card_type == 'dc' and self.loss_label == 'bce',
                use_empty_set=self.use_empty_set)
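            # predictions_to_idxs is expected to either threshold the label
            # probabilities at self.th or, given a cardinality distribution,
            # keep the top-k labels for the most likely set size k (an
            # assumption about the helper, which is defined elsewhere)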
    else:  # auto-regressive models
        # use the auto-regressive decoder to sample a label sequence;
        # label_logits is only used to compute losses when self.perminv is set
        # (no teacher forcing), while the sampled predictions are used for all
        # auto-regressive models
        predictions, label_logits = self.decoder.sample(
            img_features,
            None,
            first_token_value=0,
            replacement=self.replacement)
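        # first_token_value=0 starts decoding from token id 0, which this
        # model also uses as the eos label (see the masking below)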
        if compute_predictions:
            # mask out labels after the first eos to enforce the predicted cardinality
            sample_mask = mask_from_eos(predictions, eos_value=0, mult_before=False)
            predictions[sample_mask == 0] = self.pad_value
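            # e.g. a sampled row [5, 2, 0, 9] keeps [5, 2] and pads the rest,
            # assuming mask_from_eos zeroes the mask from the eos position
            # onwards when mult_before=False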
        else:
            predictions = None

        if compute_losses:
            # prepend a dummy first token to the target sequence and drop its
            # last element, yielding the shifted decoder input
            first_word = torch.zeros(label_target.size(0))
            shift_target = torch.cat(
                [first_word.unsqueeze(-1).to(device).long(), label_target],
                -1)[:, :-1]
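            # e.g. a target row [3, 7, 0, pad] becomes the decoder input
            # [0, 3, 7, 0]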
            if self.perminv:
                # auto-regressive decoder trained with a permutation-invariant
                # objective (e.g. lstmset and tfset):
                # apply the softmax nonlinearity before pooling across time steps
                label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

                # find the eos positions; the eos probability is the one
                # assigned to the first position of the softmax output
                # (used with bce loss only)
                eos = label_probs[:, :, 0]
                eos_pos = (label_target == 0)  # 1 only where eos sits in the gt
                eos_head = ((label_target != self.pad_value) & (label_target != 0))  # 1s at gt label positions
                eos_target = ~eos_head  # 1s from the eos position onwards
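                # worked example for a gt row [3, 7, 0, pad]:
                #   eos_pos    = [0, 0, 1, 0]
                #   eos_head   = [1, 1, 0, 0]
                #   eos_target = [0, 0, 1, 1]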
                # select the decoder steps to pool (steps corresponding to set
                # elements, i.e. labels)
                label_probs = label_probs * eos_head.float().unsqueeze(-1)
                # pool across time steps
                label_probs, _ = torch.max(label_probs, dim=1)
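                # max-pooling keeps, for every class, its highest probability
                # across the retained time steps, making the objective
                # independent of the order in which labels were generated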
                # compute label loss against the k-hot target (eos removed)
                target_k_hot = label2_k_hots(label_target, self.pad_value, remove_eos=True)
                losses['label_loss'] = self.crit(label_probs[:, 1:], target_k_hot)

                # compute eos loss: equal weight is given to the loss at the
                # true eos position and to the mean loss over the head
                # (positions before eos, whose eos target is 0)
                eos_loss = self.crit_eos(eos, eos_target.float())
                losses['eos_loss'] = 0.5 * (eos_loss * eos_pos.float()).sum(1) + \
                    0.5 * (eos_loss * eos_head.float()).sum(1) / (
                        eos_head.float().sum(1) + self.eps)
            else:
                # other auto-regressive models: recompute the logits with
                # teacher forcing (a regular forward pass over shift_target)
                label_logits, _ = self.decoder(img_features, None, shift_target)
                label_logits_v = label_logits.view(
                    label_logits.size(0) * label_logits.size(1), -1)

                # compute label loss over the flattened time steps
                label_target_v = label_target.view(-1)
                losses['label_loss'] = self.crit(label_logits_v, label_target_v)

    return losses, predictions
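
# minimal usage sketch (hypothetical `model`, `images` and `labels` tensors;
# not part of the original file):
#   training:  losses, _ = model(images, label_target=labels, compute_losses=True)
#   inference: _, preds = model(images, maxnumlabels=10, compute_predictions=True)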