in src/model.py [0:0]
def predictions_to_idxs(label_logits,
                        maxnumlabels,
                        pad_value,
                        th=1,
                        cardinality_prediction=None,
                        which_loss='bce',
                        accumulate_probs=False,
                        use_empty_set=False):
    """Convert per-label scores into a padded tensor of predicted label indices.

    The ``maxnumlabels`` highest-scoring labels of each row are selected, sorted by
    descending score. Slots that are not kept are replaced by ``pad_value``.
    Which of the top-k slots are kept is decided by:

    * ``cardinality_prediction`` (if given): the argmax over its last dimension
      gives the number of labels to keep in each row;
    * otherwise, when ``which_loss == 'td'``: slot ``i`` is kept while the sum of
      the ``i`` higher-ranked scores is still below ``th``;
    * otherwise: slots whose score is strictly greater than ``th`` are kept.

    Args:
        label_logits: 2-D tensor of per-label scores, indexed as (batch, label).
            The (0, 1] threshold range and the ``torch.log`` applied when
            ``accumulate_probs`` is set suggest these are probabilities --
            TODO confirm with callers.
        maxnumlabels: number of top-scoring labels (k) to consider per row.
        pad_value: index written into the slots that are not kept.
        th: threshold in (0, 1]. Ignored when ``cardinality_prediction`` is given.
        cardinality_prediction: optional 2-D (batch, num_cardinalities) tensor of
            cardinality scores. Column ``c`` stands for ``c`` labels when
            ``use_empty_set`` is True and for ``c + 1`` labels otherwise. This
            tensor is never modified by this function.
        which_loss: ``'td'`` selects the cumulative threshold. Any other value
            selects the per-label threshold.
        accumulate_probs: if True, before taking the argmax, add to each
            cardinality score the summed log-score of the top ``c + 1`` labels.
        use_empty_set: whether predicting zero labels is allowed. When False and
            no cardinality prediction is given, the top-1 label is always kept.

    Returns:
        Tensor of shape (batch, maxnumlabels) holding label indices sorted by
        descending score, with ``pad_value`` in every rejected slot.

    Raises:
        AssertionError: if ``th`` is not in (0, 1].
    """
    assert 0 < th <= 1
    card_offset = 0 if use_empty_set else 1

    # Top-k scores and their label indices, best first.
    probs, idxs = torch.topk(label_logits, k=maxnumlabels, dim=1,
                             largest=True, sorted=True)
    num_slots = probs.size(1)

    if cardinality_prediction is not None:
        card_scores = cardinality_prediction
        if accumulate_probs:
            # Work on a copy so the caller's tensor is not modified in place.
            card_scores = cardinality_prediction.clone()
            log_probs = torch.log(probs)
            for c in range(card_scores.size(-1)):
                # NOTE(review): this always sums c + 1 log-scores, although
                # column c represents c labels when use_empty_set=True --
                # confirm this is intended.
                card_scores[:, c] += torch.sum(log_probs[:, 0:c + 1], dim=-1)
        # Argmax 0 means one label, unless the empty set is allowed, in which
        # case it means zero labels.
        _, card_idx = torch.max(card_scores, dim=-1)
        num_kept = card_idx + card_offset
        # Keep exactly the first num_kept slots of each row.
        positions = torch.arange(num_slots, device=probs.device)
        keep = positions.unsqueeze(0) < num_kept.unsqueeze(1)
    else:
        if which_loss == 'td':
            # Cumulative threshold: slot i survives while the sum of the i
            # higher-ranked scores is below th (slot 0 always passes, th > 0).
            keep = torch.ones_like(probs, dtype=torch.bool)
            for i in range(num_slots):
                keep[:, i] = torch.sum(probs[:, 0:i], dim=-1) < th
        else:
            # Per-label probability threshold.
            keep = probs > th
        if not use_empty_set and num_slots > 0:
            # Never return an empty prediction when the empty set is disallowed.
            keep[:, 0] = True

    return idxs.masked_fill(~keep, pad_value)