def sample()

in src/modules/transformer_decoder.py [0:0]


    def sample(self,
               features,
               mask,
               greedy=True,
               temperature=1.0,
               first_token_value=0,
               replacement=True):
        """Decode ``self.seq_length`` tokens autoregressively.

        At every step ``self.forward`` is called with all ids produced so
        far (starting from a dummy first token) and the next id is chosen
        from its output, either greedily or by top-k random sampling.

        Args:
            features: Tensor whose first dimension is the batch size; it is
                passed unchanged to ``self.forward`` and its device is used
                for the dummy first token.
            mask: Passed unchanged to ``self.forward``.
            greedy: If true, pick the arg-max id at each step. Otherwise
                sample from the (at most) 10 most probable ids.
            temperature: Divisor applied to the scores before the softmax
                in the non-greedy branch.
            first_token_value: Id used to fill the dummy first token.
            replacement: If false, an id already produced for a row is set
                to ``-inf`` for that row in all later steps. Id 0 is never
                masked.

        Returns:
            Tuple ``(sampled_ids, logits)``. ``sampled_ids`` stacks the
            chosen ids along dim 1 (the dummy first token is dropped);
            ``logits`` stacks the per-step scores (after masking, when
            ``replacement`` is false) along dim 1.
        """
        # NOTE(review): assumes self.forward returns (batch, 1, vocab) for the
        # newest position -- the code squeezes dim 1 and reduces over dim 1.
        incremental_state = {}
        max_top_k = 10

        # Dummy previous word that starts the decoding.
        batch_size = features.size(0)
        first_word = torch.full((batch_size,), int(first_token_value),
                                dtype=torch.long, device=features.device)
        sampled_ids = [first_word]
        logits = []
        predicted_mask = None

        for i in range(self.seq_length):
            outputs = self.forward(features, mask,
                                   torch.stack(sampled_ids, 1),
                                   incremental_state)
            outputs = outputs.squeeze(1)

            if not replacement:
                if predicted_mask is None:
                    # Nothing has been selected before the first step.
                    predicted_mask = torch.zeros_like(outputs)
                else:
                    # Forbid the id chosen at the previous step, row by row,
                    # except id 0 which may be produced repeatedly.
                    previous = sampled_ids[i]
                    rows = torch.nonzero(previous != 0).view(-1)
                    predicted_mask[rows, previous[rows]] = float('-inf')

                # Out-of-place add: do not mutate the tensor returned by
                # self.forward.
                outputs = outputs + predicted_mask

            logits.append(outputs)

            if greedy:
                _, predicted = outputs.max(1)
                predicted = predicted.detach()
            else:
                probs = torch.nn.functional.softmax(
                    torch.div(outputs, temperature), dim=-1).detach()

                # Top-k random sampling; k cannot exceed the vocabulary size.
                k = min(max_top_k, probs.size(-1))
                topk_probs, topk_ids = torch.topk(probs, k=k, dim=1)
                choice = torch.multinomial(topk_probs, 1)
                # gather maps each row's sampled rank to that row's own id
                # (index_select would apply row 0's rank to every row).
                predicted = topk_ids.gather(1, choice).view(-1).detach()

            sampled_ids.append(predicted)

        sampled_ids = torch.stack(sampled_ids[1:], 1)
        logits = torch.stack(logits, 1)
        return sampled_ids, logits