def postprocess_span2d_output()

in paq/generation/answer_extractor/span2D_model.py


from typing import Dict, List

import numpy as np

# `AnswerSpanExtractor2DModelOutput` and `sigmoid` are assumed to be defined
# elsewhere in this module.


def postprocess_span2d_output(span2D_output: AnswerSpanExtractor2DModelOutput, features,
        max_answer_length: int, passage: str, n_best_size: int) -> List[Dict]:
    all_span_logits = span2D_output.span_logits.detach().cpu().numpy()
    all_span_masks = span2D_output.span_masks.detach().cpu().numpy()

    prelim_predictions = []
    # Loop over all the features associated with the current example.
    for feature_index in range(len(all_span_logits)):
        # We grab the predictions of the model for this feature.
        span_logits = all_span_logits[feature_index]
        span_masks = all_span_masks[feature_index]
        span_logits += -100 * (1 - span_masks)  # additively mask out invalid spans
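        # A masked cell (mask == 0) has -100 added to its logit, pushing it far below
        # any realistic span score so it cannot surface among the top candidates.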

        # This is what allows us to map positions in our logits to spans of text in the
        # original context.
        offset_mapping = features["offset_mapping"][feature_index]
        # Optional `token_is_max_context`: if provided, we remove answers that do not have
        # the maximum context available in the current feature. It is not used here, so it
        # stays None.
        token_is_max_context = None

        # Update the minimum null prediction (computed for reference; it is not used
        # further in this function).
        feature_null_score = span_logits[0, 0]
        min_null_prediction = {"offsets": (0, 0), "score": feature_null_score}

        # Take the highest-scoring (start, end) pairs directly from the 2D span logits:
        # `np.argsort(..., axis=None)` ranks the flattened matrix, and `np.unravel_index`
        # converts the flat indices back into (start, end) coordinates.
        start_indexes, end_indexes = np.unravel_index(
            np.argsort(span_logits, axis=None)[-1:-n_best_size - 10:-1],  # small buffer in case some candidates are invalid
            span_logits.shape
        )
        start_indexes, end_indexes = start_indexes.tolist(), end_indexes.tolist()
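        # E.g. in a 3x3 matrix whose largest value sits at row 1, col 2, its flat index
        # is 5, and np.unravel_index(5, (3, 3)) == (1, 2), i.e. start_index=1, end_index=2.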
        for start_index, end_index in zip(start_indexes, end_indexes):
            # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
            # to part of the input_ids that are not in the context.
            if (
                start_index >= len(offset_mapping)
                or end_index >= len(offset_mapping)
                or offset_mapping[start_index] is None
                or offset_mapping[end_index] is None
            ):
                continue
            # Don't consider answers with a length that is either < 0 or > max_answer_length.
            if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                continue
            # Don't consider answers that don't have the maximum context available (if such
            # information is provided).
            if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                continue
            prelim_predictions.append(
                {
                    "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                    "score": span_logits[start_index, end_index],
                }
            )

    # Only keep the best `n_best_size` predictions.
    predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

    # Use the offsets to gather the answer text in the original context.
    for pred in predictions:
        offsets = pred.pop("offsets")
        pred["text"] = passage[offsets[0]: offsets[1]]
        pred["start"] = offsets[0]
        pred["end"] = offsets[1]

    # In the very rare edge case where there is not a single non-null prediction, create a
    # fake prediction to avoid a failure.
    if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
        predictions.insert(0, {"text": "null", "score": -100.0, "start": 0, "end": 0})

    # Convert the raw logit scores to probabilities.
    for pred in predictions:
        pred["score"] = sigmoid(pred["score"])

    return predictions
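
A minimal, self-contained sketch of calling the function on dummy tensors. The dummy
output class, the hand-built offset mapping, and the toy logits are illustrative
assumptions; in the repository these values come from the span-2D model and from the
tokenizer's `return_offsets_mapping` output.

import torch
from dataclasses import dataclass

@dataclass
class DummySpan2DOutput:
    span_logits: torch.Tensor  # (num_features, seq_len, seq_len)
    span_masks: torch.Tensor   # 1 for valid (start, end) cells, 0 otherwise

passage = "Paris is the capital of France."
# One feature, four tokens; each offset is a character span into `passage`.
features = {"offset_mapping": [[(0, 5), (6, 8), (9, 12), (13, 20)]]}

seq_len = 4
logits = torch.zeros(1, seq_len, seq_len)
logits[0, 0, 0] = 5.0  # strongly favour the single-token span "Paris"
masks = torch.triu(torch.ones(1, seq_len, seq_len))  # only start <= end is valid

predictions = postprocess_span2d_output(
    DummySpan2DOutput(span_logits=logits, span_masks=masks),
    features,
    max_answer_length=10,
    passage=passage,
    n_best_size=3,
)
print(predictions[0]["text"])  # -> "Paris"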