# in problem.py
# Module-level imports this method relies on (ProblemTypes is defined
# elsewhere in the project):
import numpy as np
import torch
import torch.nn as nn
def decode(self, model_output, lengths=None, batch_data=None):
    """Decode the model output, for either a single sample or a batch of samples.

    Args:
        model_output: target indices.
            If it is a 1-d array, it is the output of a single sample;
            if it is a 2-d array, it is the outputs of a batch of samples.
        lengths: if not None, its shape should be consistent with
            model_output; used to drop padding before decoding.
        batch_data: the raw fields of the batch; required for MRC, where
            the answer text is sliced out of the raw passage.

    Returns:
        the decoded output: label(s) for classification, tag sequence(s)
        for sequence tagging, or answer strings for MRC.
    """
    if ProblemTypes[self.problem_type] == ProblemTypes.classification:
        if isinstance(model_output, int):  # output of a single sample
            return self.output_dict.cell(model_output)
        else:  # output of a batch
            return self.output_dict.decode(model_output)
    elif ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
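        # Normalize model_output: unwrap dict outputs and convert to a numpy array.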
        if isinstance(model_output, dict):
            model_output = list(model_output.values())[0]
        if not isinstance(model_output, np.ndarray):
            model_output = np.array(model_output)
        if len(model_output.shape) == 1:  # output of a single sample
            if lengths is None:
                outputs = np.array(self.output_dict.decode(model_output))
            else:
                outputs = np.array(self.output_dict.decode(model_output[:lengths]))
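            # Strip the <bos>/<eos> positions that were added during encoding.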
            if self.with_bos_eos:
                outputs = outputs[1:-1]
        elif len(model_output.shape) == 2:  # output of a batch of sequences
            outputs = []
            if lengths is None:
                for sample in model_output:
                    if self.with_bos_eos:
                        outputs.append(self.output_dict.decode(sample[1:-1]))
                    else:
                        outputs.append(self.output_dict.decode(sample))
            else:
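                # Truncate each padded sample to its real length before decoding.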
                for sample, length in zip(model_output, lengths):
                    if self.with_bos_eos:
                        outputs.append(self.output_dict.decode(sample[:length][1:-1]))
                    else:
                        outputs.append(self.output_dict.decode(sample[:length]))
        return outputs
    elif ProblemTypes[self.problem_type] == ProblemTypes.mrc:
        # For MRC, model_output is a dict; its first two values are the
        # start-position and end-position logits over the passage.
        answers = []
        p1, p2 = list(model_output.values())[:2]
        batch_size, c_len = p1.size()
        passage_length = lengths.numpy()
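        # Mark padded positions with 1 and real tokens with 0.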
        padding_mask = np.ones((batch_size, c_len))
        for i, single_len in enumerate(passage_length):
            padding_mask[i][:single_len] = 0
        device = p1.device
        # bool mask: recent PyTorch versions reject byte masks in masked_fill_
        padding_mask = torch.from_numpy(padding_mask).bool().to(device)
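        # Padded positions can never be the answer start or end.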
        p1.data.masked_fill_(padding_mask.data, float('-inf'))
        p2.data.masked_fill_(padding_mask.data, float('-inf'))
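        # Joint span score: score[b, s, e] = log P(start=s) + log P(end=e);
        # the strictly-lower triangle is -inf so only spans with s <= e survive.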
        ls = nn.LogSoftmax(dim=1)
        mask = (torch.ones(c_len, c_len) * float('-inf')).to(device) \
            .tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
        score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
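        # Best start for every end, then best end overall; finally gather the
        # start index that pairs with the chosen end.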
        score, s_idx = score.max(dim=1)
        score, e_idx = score.max(dim=1)
        # squeeze(1), not squeeze(): keep the batch dim when batch_size == 1
        s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)
        # Recover the answer text from the raw passage; these are keys into batch_data.
        passage_text = 'extra_passage_text'
        passage_token_offsets = 'extra_passage_token_offsets'
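        # Each offset entry is a (char_start, char_end) pair for one token.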
        for i in range(batch_size):
            char_s_idx, _ = batch_data[passage_token_offsets][i][s_idx[i]]
            _, char_e_idx = batch_data[passage_token_offsets][i][e_idx[i]]
            answer = batch_data[passage_text][i][char_s_idx:char_e_idx]
            answers.append(answer)
        return answers