def sample()

in src/modules/rnn_decoder.py [0:0]


    def sample(self,
               features,
               mask,
               greedy=True,
               temperature=1.0,
               first_token_value=0,
               replacement=True):
        """Generate captions for given image features."""
        logits = []
        device = features.device
        # Mean-pool the features along the last dimension to build the
        # initial recurrent input.
        avg_feats = torch.mean(features, dim=-1)

        inputs = avg_feats
        states = None
        batch_size = features.size(0)
        # Start every sequence with the first token (e.g. <start>).
        prev_word = torch.full((batch_size, 1), first_token_value,
                               dtype=torch.long, device=device)
        sampled_ids = [prev_word]
        prev_word = self.embed(prev_word).squeeze(1)
        for i in range(self.seq_length):
            # One decoding step: attend over the features and update the
            # recurrent state.
            v, states, att_coeffs = self.core(inputs, features, prev_word, states)
            inputs = v
            outputs = self.linear(v).squeeze(1)
            if not replacement:
                # Sampling without replacement: accumulate a mask that sets
                # every previously generated token id (except pad id 0) to
                # -inf so it cannot be selected again.
                if i == 0:
                    predicted_mask = torch.zeros_like(outputs)
                else:
                    batch_ind = [j for j in range(batch_size) if sampled_ids[i][j] != 0]
                    sampled_ids_new = sampled_ids[i][batch_ind]
                    predicted_mask[batch_ind, sampled_ids_new] = float('-inf')

                # Mask previously selected ids.
                outputs = outputs + predicted_mask

            logits.append(outputs)
            # outputs = torch.nn.functional.log_softmax(outputs, dim=1)
            if greedy:
                # Greedy decoding: take the arg-max token.
                _, predicted = outputs.max(1)
                predicted = predicted.detach()
            else:
                # Temperature-scaled top-k sampling.
                k = 10
                prob_prev = torch.nn.functional.softmax(outputs / temperature, dim=-1).detach()
                # Keep the k most probable tokens and sample one of them per
                # batch row, then map each row's sampled position back to its
                # vocabulary id with gather (indices[j, sampled[j]]).
                prob_prev_topk, indices = torch.topk(prob_prev, k=k, dim=1)
                sampled = torch.multinomial(prob_prev_topk, 1)
                predicted = torch.gather(indices, 1, sampled).view(-1).detach()
            sampled_ids.append(predicted)
            prev_word = self.embed(predicted)

        # (batch, seq_length, vocab) logits; drop the initial start token
        # from the sampled ids to get a (batch, seq_length) tensor.
        logits = torch.stack(logits, 1)
        sampled_ids = torch.stack(sampled_ids[1:], 1)

        return sampled_ids, logits
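
A minimal usage sketch, assuming a decoder module exposing this method and
a feature tensor shaped (batch, feat_dim, regions) so that the mean over
the last dimension pools across regions; the ``RNNDecoder`` constructor,
its arguments, and the shapes below are assumptions for illustration, not
part of the source:

    import torch

    # Hypothetical setup; RNNDecoder and its arguments are illustrative.
    decoder = RNNDecoder(vocab_size=10000, hidden_size=512, seq_length=20)
    decoder.eval()

    features = torch.randn(4, 512, 36)   # (batch, feat_dim, regions), assumed
    mask = torch.ones(4, 36)             # accepted but unused by sample()

    with torch.no_grad():
        # Greedy decoding.
        ids, logits = decoder.sample(features, mask, greedy=True)
        # Temperature-scaled top-k sampling instead.
        ids, logits = decoder.sample(features, mask, greedy=False,
                                     temperature=0.8)

    # ids: (4, 20) token ids; logits: (4, 20, vocab) pre-softmax scores.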
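
To make the sampling branch concrete, here is a self-contained sketch of
the temperature-scaled top-k sampling step over a dummy logit matrix; the
shapes and values are illustrative only:

    import torch

    logits = torch.randn(4, 100)          # (batch, vocab), dummy scores
    k, temperature = 10, 0.8

    probs = torch.softmax(logits / temperature, dim=-1)
    topk_probs, topk_ids = torch.topk(probs, k=k, dim=1)   # (batch, k)

    # Sample a position within each row's top-k, then map it back to a
    # vocabulary id with gather: one independent draw per batch row.
    pos = torch.multinomial(topk_probs, num_samples=1)     # (batch, 1)
    predicted = torch.gather(topk_ids, 1, pos).view(-1)    # (batch,)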