def decode()

in problem.py [0:0]


    def decode(self, model_output, lengths=None, batch_data=None):
        """Decode model output back into its original (human readable) form.

        Args:
            model_output: what the model produced; interpretation depends on
                ``self.problem_type``:

                * classification: a single integer index (one sample) or
                  anything ``self.output_dict.decode`` accepts (a batch).
                * sequence_tagging: target indices, or a dict whose first
                  value holds them. A 1d array is the output of one sample;
                  a 2d array is the output of a batch of samples.
                * mrc: a dict whose first two values are the start and end
                  position score tensors, each 2d (batch, passage length).
            lengths: if not None, the valid length of each output; its shape
                should be consistent with model_output. Required for mrc.
            batch_data: only used for mrc (and required there). Must provide
                ``'extra_passage_text'`` and ``'extra_passage_token_offsets'``,
                both indexed by sample.

        Returns:
            * classification: the decoded label(s).
            * sequence_tagging: a numpy array of decoded tags for a single
              sample, or a list with one decoded sequence per sample for a
              batch (BOS/EOS positions removed when ``self.with_bos_eos``).
            * mrc: a list of answer strings, one per sample.
            * any other problem type: None.

        Raises:
            ValueError: if a sequence_tagging output is neither 1d nor 2d, or
                if ``lengths`` / ``batch_data`` is missing for mrc.

        """
        problem_type = ProblemTypes[self.problem_type]

        if problem_type == ProblemTypes.classification:
            # numpy integer scalars (e.g. the result of an argmax) are a
            # single sample too, not a batch.
            if isinstance(model_output, (int, np.integer)):     # output of a sample
                return self.output_dict.cell(int(model_output))
            return self.output_dict.decode(model_output)        # output of a batch

        if problem_type == ProblemTypes.sequence_tagging:
            if isinstance(model_output, dict):
                model_output = next(iter(model_output.values()))
            if not isinstance(model_output, np.ndarray):
                model_output = np.array(model_output)

            rank = len(model_output.shape)
            if rank == 1:        # output of a sample
                sample = model_output if lengths is None else model_output[:lengths]
                outputs = np.array(self.output_dict.decode(sample))
                if self.with_bos_eos:
                    outputs = outputs[1:-1]
            elif rank == 2:      # output of a batch of sequence
                if lengths is None:
                    samples = list(model_output)
                else:
                    # cut off the padding first so that [1:-1] below removes
                    # the real EOS position rather than a padding position
                    samples = [sample[:length] for sample, length in zip(model_output, lengths)]
                if self.with_bos_eos:
                    samples = [sample[1:-1] for sample in samples]
                outputs = [self.output_dict.decode(sample) for sample in samples]
            else:
                raise ValueError(
                    "model_output of a sequence tagging problem must be 1d or 2d, "
                    "but got %d dimension(s)" % rank)
            return outputs

        if problem_type == ProblemTypes.mrc:
            # for mrc, model_output is dict
            if lengths is None or batch_data is None:
                raise ValueError("lengths and batch_data are required to decode mrc output")

            position_scores = list(model_output.values())
            p1, p2 = position_scores[0], position_scores[1]
            batch_size, c_len = p1.size()
            device = p1.device

            # True on padding positions (index >= passage length). Building the
            # mask with a comparison yields the mask dtype native to the
            # installed torch version and also works when lengths is on GPU.
            passage_length = torch.as_tensor(lengths).long().to(device).view(-1, 1)
            positions = torch.arange(c_len, dtype=torch.long, device=device).view(1, -1)
            padding_mask = positions >= passage_length

            # NOTE: this modifies the caller's score tensors in place.
            p1.data.masked_fill_(padding_mask, float('-inf'))
            p2.data.masked_fill_(padding_mask, float('-inf'))

            ls = nn.LogSoftmax(dim=1)
            # -inf strictly below the diagonal: forbids spans whose start
            # position is after their end position.
            span_mask = torch.full((c_len, c_len), float('-inf'), device=device).tril(-1)
            # score[b, s, e] = log P(start = s) + log P(end = e)
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + span_mask.unsqueeze(0)
            score, s_idx = score.max(dim=1)     # best start for every end
            score, e_idx = score.max(dim=1)     # best end overall
            # squeeze only the gathered dimension, so that a batch of size 1
            # still yields a 1d index tensor
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)

            # decode mrc answer text: map token spans back to character spans
            passage_text = batch_data['extra_passage_text']
            passage_token_offsets = batch_data['extra_passage_token_offsets']
            answers = []
            for i, (start_token, end_token) in enumerate(zip(s_idx.tolist(), e_idx.tolist())):
                char_s_idx, _ = passage_token_offsets[i][start_token]
                _, char_e_idx = passage_token_offsets[i][end_token]
                answers.append(passage_text[i][char_s_idx:char_e_idx])
            return answers

        # other problem types are not decoded by this method
        return None