def _sample()

in captioning/models/AttModel_for_coco_caption_task.py [0:0]


    def _sample(self, fc_feats, att_feats, trace_feats, box_feats, att_masks=None, trace_masks=None, show_gate_labels=None, task=None, opt={}):

        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        sample_n = int(opt.get('sample_n', 1))
        group_size = opt.get('group_size', 1)
        output_logsoftmax = opt.get('output_logsoftmax', 1)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)
        remove_bad_endings = opt.get('remove_bad_endings', 0)
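        # Delegate to beam search when beam_size > 1, or to diverse (group) sampling when group_size > 1.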
        if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
            return self._sample_beam(fc_feats, att_feats, trace_feats, box_feats, att_masks, trace_masks, show_gate_labels, task, opt)
        if group_size > 1:
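            # Note: diverse sampling follows the base AttModel path and is not passed the trace/box features.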
            return self._diverse_sample(fc_feats, att_feats, att_masks, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size*sample_n)

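        # Prepare (embed/project) the raw visual, box, and trace features into the
        # decoder-ready tensors consumed by get_logprobs_state below.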
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, trace_feats_to_decoder = \
            self._prepare_feature(fc_feats, att_feats, trace_feats, box_feats, att_masks, trace_masks)

        if sample_n > 1:
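            # Draw sample_n samples per image by replicating the prepared features along the batch dimension.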
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n,
                [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
            )

        trigrams = [] # will be a list of batch_size dictionaries
        
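        # seq holds the sampled word ids (pre-filled with pad_idx); seqLogprobs keeps the
        # full log-probability distribution produced at every step.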
        seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
        if task == 'both':
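            # Start the predicted trace with a single all-zero box per example; one predicted
            # box is appended per generated word below.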
            tmp_trace_feats = torch.zeros([trace_feats_to_decoder.shape[0], 1, trace_feats_to_decoder.shape[2]]).to(trace_masks.device)
        for t in range(self.seq_length + 1):
            if t == 0: # input <bos>
                it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long)

            if task == 'caption' or task == 'show':
                logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, trace_feats_to_decoder, trace_masks, show_gate_labels, task, state, output_logsoftmax=output_logsoftmax)
            elif task == 'both':
                logprobs, state, output_trace = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks,
                                                          tmp_trace_feats, trace_masks, show_gate_labels, task, state,
                                                          output_logsoftmax=output_logsoftmax)
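                # Take the box predicted for the current step and recompute its 5th entry as the
                # box area, assuming an (x1, y1, x2, y2, area) layout.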
                output_trace = output_trace[:, t]
                output_trace[:, 4] = (output_trace[:, 2] - output_trace[:, 0]) * (output_trace[:, 3] - output_trace[:, 1])
                tmp_trace_feats = torch.cat([tmp_trace_feats, output_trace.unsqueeze(1)], 1)

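            # Optionally forbid immediately repeating the previous word by masking it out.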
            if decoding_constraint and t > 0:
                tmp = logprobs.new_zeros(logprobs.size())
                tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
                logprobs = logprobs + tmp

            if remove_bad_endings and t > 0:
                tmp = logprobs.new_zeros(logprobs.size())
                prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
                # Make it impossible to generate bad_endings
                tmp[torch.from_numpy(prev_bad), 0] = float('-inf')  # boolean mask over the batch
                logprobs = logprobs + tmp

            # Trigram blocking (adapted from https://github.com/lukemelas/image-paragraph-captioning):
            # penalize any trigram that has already been generated in this sequence.
            if block_trigrams and t >= 3:
                # Store trigram generated at last step
                prev_two_batch = seq[:,t-3:t-1]
                for i in range(batch_size): # note: assumes sample_n == 1, i.e. batch_size == seq.size(0)
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    current  = seq[i][t-1]
                    if t == 3: # initialize
                        trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
                    elif t > 3:
                        if prev_two in trigrams[i]: # add to list
                            trigrams[i][prev_two].append(current)
                        else: # create list
                            trigrams[i][prev_two] = [current]
                # Block used trigrams at next step
                prev_two_batch = seq[:,t-2:t]
                mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size
                for i in range(batch_size):
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    if prev_two in trigrams[i]:
                        for j in trigrams[i][prev_two]:
                            mask[i,j] += 1
                # Apply mask to log probs
                # Hard-block alternative: logprobs = logprobs - (mask * 1e9)
                alpha = 2.0
                # Each prior occurrence of a trigram scales its probability by (1/2)**alpha
                # (ln(1/2) ≈ -0.693); larger alpha approaches hard blocking.
                logprobs = logprobs + (mask * -0.693 * alpha)

            # sample the next word
            if t == self.seq_length: # stop once the maximum length is reached
                break
            it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

            # stop when all finished
            if t == 0:
                unfinished = it != self.eos_idx
            else:
                it[~unfinished] = self.pad_idx # pad finished sequences so their eos_idx is not overwritten
                logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs)
                unfinished = unfinished & (it != self.eos_idx)
            seq[:,t] = it
            seqLogprobs[:,t] = logprobs
            # quit loop if all sequences have finished
            if unfinished.sum() == 0:
                break

        ### For decoder evaluation: cut at the ground-truth caption length, since controlled caption generation assumes the caption length is known.
        # if  task != 'both':
        #     for i in range(trace_masks.shape[0]):
        #         tmp_num = trace_masks[i].sum().long()
        #         seq[i, tmp_num:] = 0
        #         seqLogprobs[i, tmp_num:, :] = 0

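        # For 'caption'/'show', return the sampled words and their per-step log-probs;
        # for 'both', also return the predicted trace boxes, zero-padded to the sequence length.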
        if task != 'both':
            return seq, seqLogprobs
        else:
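            # Drop the leading all-zero box and the box appended at the final step before zero-padding.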
            tmp_trace_feats = tmp_trace_feats[:, 1:-1]
            return seq, seqLogprobs, torch.cat([tmp_trace_feats,
                                                torch.zeros([seq.shape[0], seq.shape[1]-tmp_trace_feats.shape[1], tmp_trace_feats.shape[2]]).to(seq.device)], 1)
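
For orientation, below is a minimal sketch of driving this sampler directly. The `model` instance, feature shapes, and tensor values are placeholders (assumptions, not taken from this file); only `opt` keys that the method actually reads above are set.

    import torch

    # Placeholder inputs; shapes are illustrative only.
    fc_feats    = torch.randn(2, 2048)        # global image features
    att_feats   = torch.randn(2, 36, 2048)    # per-region attention features
    box_feats   = torch.randn(2, 36, 5)       # per-region box features
    trace_feats = torch.randn(2, 20, 5)       # grounding-trace boxes
    att_masks   = None
    trace_masks = torch.ones(2, 20)

    opt = {'sample_method': 'greedy', 'beam_size': 1, 'temperature': 1.0,
           'sample_n': 1, 'block_trigrams': 1}

    # `model` is assumed to be an instance of the AttModel subclass defined in this file.
    seq, seqLogprobs = model._sample(fc_feats, att_feats, trace_feats, box_feats,
                                     att_masks, trace_masks, task='caption', opt=opt)
    # With task='both', a third tensor of predicted trace boxes (zero-padded to
    # seq_length) is returned as well.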