def forward()

in src/model.py [0:0]


    def forward(self, img_inputs, label_target=None, maxnumlabels=0,
                keep_cnn_gradients=False, compute_losses=False, compute_predictions=False):

        losses = {}
        predictions = None

        assert (label_target is not None and compute_losses) or (label_target is None and not compute_losses)

        if not compute_losses and not compute_predictions:
            return losses, predictions

        # encode image
        img_features = self.image_encoder(img_inputs, keep_cnn_gradients)

        if self.decoder_ff:
            # use ff decoder to predict set of labels and cardinality
            label_logits, cardinality_logits = self.decoder(img_features)

            if compute_losses:
                # label target to k_hot
                target_k_hot = label2_k_hots(label_target, self.pad_value)
                # cardinality target
                cardinality_target = target_k_hot.sum(dim=-1).unsqueeze(1)
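                # label2_k_hots presumably scatters the label indices into a multi-hot vector,
                # e.g. a target row [3, 7, <pad>, <pad>] would become a vector with 1s at
                # indices 3 and 7, so summing it gives the number of labels in the image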

                # compute labels loss
                losses['label_loss'] = self.crit(label_logits, target_k_hot)

                # compute cardinality loss if needed
                if self.crit_cardinality is not None:
                    # subtract the offset from the cardinality target so it matches class indices:
                    # when use_empty_set is False, the first class corresponds to cardinality 1;
                    # otherwise the first class corresponds to cardinality 0 (the empty set)
                    offset = 0 if self.use_empty_set else 1
                    losses['cardinality_loss'] = self.crit_cardinality(
                        cardinality_logits, (cardinality_target.squeeze() - offset).long())
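                    # e.g. with use_empty_set=False, a cardinality target of 4 maps to class 3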

            if compute_predictions:
                # build the cardinality distribution (in log space), if a cardinality model is used
                if self.card_type == 'dc' and self.loss_label == 'bce':
                    offset = 0 if self.use_empty_set else 1
                    cardinality = torch.log(DC(cardinality_logits, dataset=self.dataset))
                    u_term = np.array(list(range(cardinality.size(-1)))) + offset
                    u_term = u_term * self.u_term
                    u_term = torch.from_numpy(u_term).to(device).unsqueeze(0).float()
                    cardinality = cardinality + u_term
                elif self.card_type == 'cat':
                    cardinality = torch.nn.functional.log_softmax(cardinality_logits + self.eps, dim=-1)
                else:
                    cardinality = None
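                    # no cardinality model: predictions_to_idxs presumably falls back to
                    # thresholding the label probabilities with self.th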

                # apply nonlinearity to label logits
                if self.loss_label == 'td':
                    label_probs = nn.functional.softmax(label_logits, dim=-1)
                else:
                    label_probs = torch.sigmoid(label_logits)

                # get label ids
                predictions = predictions_to_idxs(
                    label_probs,
                    maxnumlabels,
                    self.pad_value,
                    th=self.th,
                    cardinality_prediction=cardinality,
                    which_loss=self.loss_label,
                    accumulate_probs=self.card_type == 'dc' and self.loss_label == 'bce',
                    use_empty_set=self.use_empty_set)

        else:  # auto-regressive models

            # use the auto-regressive decoder to predict labels (sample function)
            # the returned label_logits are only used to compute losses when self.perminv is set (no teacher forcing);
            # the sampled predictions are used for all auto-regressive models
            predictions, label_logits = self.decoder.sample(
                img_features,
                None,
                first_token_value=0,
                replacement=self.replacement)

            if compute_predictions:
                # mask out labels sampled after the eos token (its position determines the predicted cardinality)
                sample_mask = mask_from_eos(predictions, eos_value=0, mult_before=False)
                predictions[sample_mask == 0] = self.pad_value
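                # e.g. a sampled row [12, 4, 0, 9] becomes [12, 4, <pad>, <pad>], assuming
                # mask_from_eos zeroes the mask from the eos position (value 0) onwards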
            else:
                predictions = None

            if compute_losses:
                # prepend a dummy start token (0) to the target sequence and drop its last element
                first_word = torch.zeros(label_target.size(0))
                shift_target = torch.cat([first_word.unsqueeze(-1).to(device).long(), label_target],
                                         -1)[:, :-1]
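                # e.g. a target row [3, 7, 0, <pad>] becomes the decoder input [0, 3, 7, 0]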
                if self.perminv:
                    # autoregressive mode for decoder when training with permutation invariant objective
                    # e.g. lstmset and tfset

                    # apply softmax nonlinearity before pooling across timesteps
                    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

                    # find idxs for eos label
                    # eos probability is the one assigned to the first position of the softmax
                    # this is used with bce loss only
                    eos = label_probs[:, :, 0]
                    eos_pos = (label_target == 0)  # all zeros except position where eos is in the gt
                    eos_head = ((label_target != self.pad_value) & (label_target != 0))  # 1s for gt label positions, 0s starting from eos position in the gt
                    eos_target = ~eos_head  # 0s for gt label positions, 1s starting from eos position in the gt
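                    # e.g. for a target row [3, 7, 0, <pad>] (0 = eos):
                    #   eos_pos    = [0, 0, 1, 0]
                    #   eos_head   = [1, 1, 0, 0]
                    #   eos_target = [0, 0, 1, 1]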

                    # select transformer steps to pool (steps corresponding to set elements, i.e. labels)
                    label_probs = label_probs * eos_head.float().unsqueeze(-1)

                    # pool
                    label_probs, _ = torch.max(label_probs, dim=1)
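                    # the max over timesteps keeps, for each label, the highest probability the
                    # decoder assigned to it at any step, which makes the loss independent of the
                    # order in which the labels are produced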

                    # compute label loss
                    target_k_hot = label2_k_hots(label_target, self.pad_value, remove_eos=True)
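                    # remove_eos=True presumably drops the eos class from the k-hot target so it
                    # lines up with label_probs[:, 1:] below; eos itself is supervised by the
                    # separate eos loss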
                    loss = self.crit(label_probs[:, 1:], target_k_hot)
                    losses['label_loss'] = loss

                    # compute eos loss
                    eos_loss = self.crit_eos(eos, eos_target.float())
                    # eos loss is computed for all timesteps <= eos in gt and
                    # equally penalizes the head (all 0s) and the true eos position (1)
                    losses['eos_loss'] = 0.5 * (eos_loss * eos_pos.float()).sum(1) + \
                                    0.5 * (eos_loss * eos_head.float()).sum(1) / (
                                            eos_head.float().sum(1) + self.eps)

                else:
                    # other auto-regressive models:
                    # recompute the logits with teacher forcing (a standard forward pass over shift_target)
                    label_logits, _ = self.decoder(img_features, None, shift_target)
                    label_logits_v = label_logits.view(label_logits.size(0) * label_logits.size(1), -1)

                    # compute label loss
                    label_target_v = label_target.view(-1)
                    loss = self.crit(label_logits_v, label_target_v)
                    losses['label_loss'] = loss

        return losses, predictions
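
A minimal usage sketch (the model/tensor names and the maxnumlabels value below are assumptions for illustration, not taken from the file):

    # training: ground-truth labels are required whenever losses are computed
    losses, _ = model(
        img_inputs,
        label_target=label_target,
        maxnumlabels=label_target.size(1),
        keep_cnn_gradients=True,
        compute_losses=True)

    # inference: no targets; only the predicted label set is returned
    _, predictions = model(
        img_inputs,
        maxnumlabels=10,
        compute_predictions=True)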