src/modules/rnn_decoder.py [149:175]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            outputs = outputs.squeeze(1)
            if not replacement:
                # predicted mask
                if i == 0:
                    predicted_mask = torch.zeros(outputs.shape).float().to(device)
                else:
                    batch_ind = [j for j in range(fs) if sampled_ids[i][j] != 0]
                    sampled_ids_new = sampled_ids[i][batch_ind]
                    predicted_mask[batch_ind, sampled_ids_new] = float('-inf')

                # mask previously selected ids
                outputs += predicted_mask

            logits.append(outputs)
            # outputs = torch.nn.functional.log_softmax(outputs, dim=1)
            if greedy:
                _, predicted = outputs.max(1)
                predicted = predicted.detach()
            else:
                k = 10
                prob_prev = torch.div(outputs.squeeze(1), temperature)
                prob_prev = torch.nn.functional.softmax(prob_prev, dim=-1).data
                # top k random sampling
                prob_prev_topk, indices = torch.topk(prob_prev, k=k, dim=1)
                predicted = torch.multinomial(prob_prev_topk, 1).view(-1)
                predicted = torch.index_select(indices, dim=1, index=predicted)[:, 0].detach()
            sampled_ids.append(predicted)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



src/modules/transformer_decoder.py [262:291]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            outputs = outputs.squeeze(1)
            if not replacement:
                # predicted mask
                if i == 0:
                    predicted_mask = torch.zeros(outputs.shape).float().to(device)
                else:
                    batch_ind = [j for j in range(fs) if sampled_ids[i][j] != 0]
                    sampled_ids_new = sampled_ids[i][batch_ind]
                    predicted_mask[batch_ind, sampled_ids_new] = float('-inf')

                # mask previously selected ids
                outputs += predicted_mask

            # add outputs to list
            logits.append(outputs)

            if greedy:
                _, predicted = outputs.max(1)
                predicted = predicted.detach()
            else:
                k = 10
                prob_prev = torch.div(outputs.squeeze(1), temperature)
                prob_prev = torch.nn.functional.softmax(prob_prev, dim=-1).data

                # top k random sampling
                prob_prev_topk, indices = torch.topk(prob_prev, k=k, dim=1)
                predicted = torch.multinomial(prob_prev_topk, 1).view(-1)
                predicted = torch.index_select(indices, dim=1, index=predicted)[:, 0].detach()

            sampled_ids.append(predicted)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



