in paq/generation/answer_extractor/span2D_model.py [0:0]
from typing import Dict, List

import numpy as np


def sigmoid(x):
    # NOTE: stand-in for the module's own `sigmoid` helper, whose definition or
    # import is not shown in this excerpt.
    return 1.0 / (1.0 + np.exp(-x))


# `AnswerSpanExtractor2DModelOutput` is assumed to be defined elsewhere in this
# module.
def postprocess_span2d_output(span2D_output: AnswerSpanExtractor2DModelOutput, features,
                              max_answer_length: int, passage: str, n_best_size: int) -> List[Dict]:
    all_span_logits = span2D_output.span_logits.detach().cpu().numpy()
    all_span_masks = span2D_output.span_masks.detach().cpu().numpy()
    prelim_predictions = []
    # Loop over all the features associated with the current example.
    for feature_index in range(len(all_span_logits)):
        # Grab the model's predictions for this feature.
        span_logits = all_span_logits[feature_index]
        span_masks = all_span_masks[feature_index]
        span_logits += -100 * (1 - span_masks)  # mask out invalid spans
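        # Masked positions now carry a score of roughly -100, so they sink to
        # the bottom of the argsort below and are effectively never selected.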
        # The offset mapping lets us map positions in the logits back to spans
        # of text in the original context.
        offset_mapping = features["offset_mapping"][feature_index]
        # Optional `token_is_max_context`: if provided, we drop answers that do
        # not have the maximum context available in the current feature. (It is
        # always None here, so the guard below is currently a no-op.)
        token_is_max_context = None
        # Record the null prediction, i.e. the score of the (0, 0) "no answer"
        # span. (Computed but not used further below.)
        feature_null_score = span_logits[0, 0]
        min_null_prediction = {"offsets": (0, 0), "score": feature_null_score}
        # Consider the `n_best_size` highest-scoring (start, end) pairs over the
        # flattened 2-D span logits (rather than separate top-k passes over 1-D
        # start and end logits).
        start_indexes, end_indexes = np.unravel_index(
            np.argsort(span_logits, axis=None)[-1:-n_best_size - 10:-1],  # extra candidates in case some are invalid
            span_logits.shape
        )
        start_indexes, end_indexes = start_indexes.tolist(), end_indexes.tolist()
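        # Illustration (hypothetical numbers): for a 3x3 `span_logits`, if the
        # two largest flat indices from the argsort are [4, 2], np.unravel_index
        # maps them back to the 2-D coordinates (1, 1) and (0, 2), i.e.
        # start_indexes = [1, 0] and end_indexes = [1, 2].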
        for start_index, end_index in zip(start_indexes, end_indexes):
            # Skip out-of-scope answers: indices that are out of bounds or that
            # fall on parts of the input_ids outside the context.
            if (
                start_index >= len(offset_mapping)
                or end_index >= len(offset_mapping)
                or offset_mapping[start_index] is None
                or offset_mapping[end_index] is None
            ):
                continue
            # Skip answers whose length is either < 0 or > max_answer_length.
            if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                continue
            # Skip answers that don't have the maximum context available (when
            # such information is provided).
            if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                continue
            prelim_predictions.append(
                {
                    "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                    "score": span_logits[start_index, end_index],
                }
            )
    # Only keep the best `n_best_size` predictions.
    predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
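    # Sorting by raw logits is safe here: the sigmoid applied at the end is
    # monotonic, so the ranking is unchanged by the score conversion.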
    # Use the offsets to gather the answer text from the original context.
    for pred in predictions:
        offsets = pred.pop("offsets")
        pred["text"] = passage[offsets[0]: offsets[1]]
        pred["start"] = offsets[0]
        pred["end"] = offsets[1]
    # In the rare edge case where there is not a single non-null prediction, we
    # create a dummy prediction to avoid a failure downstream.
    if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
        predictions.insert(0, {"text": "null", "score": -100.0, "start": 0, "end": 0})
    # Convert raw logits into probabilities via the sigmoid (each span is scored
    # independently; the scores are not softmax-normalized across spans).
    for pred in predictions:
        pred["score"] = sigmoid(pred["score"])
    return predictions
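

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original module): a minimal sketch
# showing the expected input shapes, assuming the real extractor produces one
# (seq_len x seq_len) logit grid per feature. `SimpleNamespace` merely stands
# in for `AnswerSpanExtractor2DModelOutput`; the passage, offsets, and tensors
# below are synthetic.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch

    passage = "PAQ is a large bank of machine-generated QA pairs."
    # Pretend the passage tokenizes to five tokens with these character offsets.
    offsets = [(0, 3), (4, 6), (7, 8), (9, 14), (15, 19)]
    seq_len = len(offsets)
    fake_output = SimpleNamespace(
        span_logits=torch.randn(1, seq_len, seq_len),
        # Upper-triangular mask: only spans with end >= start are valid.
        span_masks=torch.triu(torch.ones(1, seq_len, seq_len)),
    )
    features = {"offset_mapping": [offsets]}
    predictions = postprocess_span2d_output(
        fake_output, features, max_answer_length=4, passage=passage, n_best_size=3
    )
    for prediction in predictions:
        print(prediction)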