def predictions_to_idxs()

in src/model.py [0:0]


def predictions_to_idxs(label_logits,
                        maxnumlabels,
                        pad_value,
                        th=1,
                        cardinality_prediction=None,
                        which_loss='bce',
                        accumulate_probs=False,
                        use_empty_set=False):
    """Turn per-label scores into a padded tensor of predicted label indices.

    The ``maxnumlabels`` highest-scoring labels of each row are selected
    (sorted by descending score). Positions that are not kept are overwritten
    with ``pad_value``. Which positions are kept depends on the arguments:

    * ``cardinality_prediction`` given: the first ``argmax + card_offset``
      positions, where ``card_offset`` is 0 when ``use_empty_set`` else 1
      (``th`` is ignored on this path);
    * ``which_loss == 'td'``: position ``i`` is kept while the scores of
      positions ``0..i-1`` sum to less than ``th``;
    * otherwise: positions whose score is strictly greater than ``th``.

    On the two threshold paths, position 0 is always kept when
    ``use_empty_set`` is False.

    Args:
        label_logits: score tensor, indexed as (batch, labels) here. Scores
            are compared against ``th`` in (0, 1], so they are presumably
            probabilities rather than raw logits -- confirm with callers.
        maxnumlabels: number of top-scoring labels considered per row.
        pad_value: value written into discarded positions.
        th: threshold, must satisfy ``0 < th <= 1``.
        cardinality_prediction: optional (batch, num_cardinalities) scores.
            When given, its argmax decides how many labels are kept. The
            tensor passed in is never modified.
        which_loss: ``'td'`` selects the cumulative threshold; any other value
            selects the per-label threshold.
        accumulate_probs: if True, add ``sum(log(top (c+1) scores))`` to
            cardinality column ``c`` before taking the argmax.
        use_empty_set: whether a prediction with zero labels is allowed.

    Returns:
        Tensor of top-k label indices (as returned by ``torch.topk``) of shape
        (batch, maxnumlabels), with ``pad_value`` in discarded positions.

    Raises:
        AssertionError: if ``th`` is outside (0, 1].
    """
    assert 0 < th <= 1

    # A cardinality argmax of c means c + 1 labels, unless the empty set is
    # allowed, in which case it means c labels.
    card_offset = 0 if use_empty_set else 1

    # Top-k scores and their label indices, sorted by descending score.
    probs, idxs = torch.topk(label_logits, k=maxnumlabels, dim=1, largest=True, sorted=True)
    idxs_clone = idxs.clone()
    num_slots = probs.size(1)

    if cardinality_prediction is not None:
        card_scores = cardinality_prediction
        if accumulate_probs:
            # Accumulate into a copy so the caller's tensor is not mutated.
            card_scores = cardinality_prediction.clone()
            for c in range(card_scores.size(-1)):
                # log of the product of the top (c + 1) scores
                card_scores[:, c] += torch.sum(torch.log(probs[:, 0:c + 1]), dim=-1)

        # Predicted cardinality index for each row.
        _, card_idx = torch.max(card_scores, dim=-1)

        # Keep slot i iff i < predicted number of labels. The condition is
        # monotone in i, so the kept slots always form a prefix.
        slots = torch.arange(num_slots, device=probs.device)
        keep = slots.unsqueeze(0) < (card_idx + card_offset).unsqueeze(1)
    else:
        if which_loss == 'td':
            # Cumulative threshold: slot i is kept while the scores before it
            # sum to less than th. Slot 0 has an empty prefix (sum 0 < th).
            keep = torch.ones_like(probs, dtype=torch.bool)
            for i in range(num_slots):
                keep[:, i] = torch.sum(probs[:, 0:i], dim=-1) < th
        else:
            # Per-label probability threshold.
            keep = probs > th

        if not use_empty_set:
            # At least one label must be predicted.
            keep[:, 0] = True

    idxs_clone[~keep] = pad_value

    return idxs_clone