def _convert_to_span_matrix()

in paq/generation/answer_extractor/span2D_model.py [0:0]


        def _convert_to_span_matrix(start_positions, end_positions):
            """Build a binary span-label matrix from padded start/end index sequences.

            Closure helper: reads ``span_logits`` and ``sequence_length`` from the
            enclosing scope.

            Args:
                start_positions: per-example sequences of gold start indices. A
                    negative entry ends the example's list (the original comment
                    says -1 is used as the null indicator).
                end_positions: per-example sequences of gold end indices, aligned
                    element-wise with ``start_positions`` and padded the same way.

            Returns:
                A tensor created with ``torch.zeros_like(span_logits)`` (same shape,
                dtype and device; commented as ``[B, L, L]``). It holds ``1.`` at
                ``[i, start, end]`` for each gold span of example ``i`` and ``0.``
                elsewhere.

            Raises:
                ValueError: if a non-negative start or end index is not smaller
                    than ``sequence_length``.
            """
            span_labels = torch.zeros_like(span_logits)  # [B, L, L]
            for i, (start_post, end_post) in enumerate(zip(start_positions, end_positions)):
                for start_idx, end_idx in zip(start_post, end_post):
                    # A negative index (-1 is the null indicator) marks padding:
                    # stop reading this example's spans at the first one.
                    if start_idx < 0 or end_idx < 0:
                        break
                    # Explicit check rather than `assert`, which `python -O` strips.
                    if not (start_idx < sequence_length and end_idx < sequence_length):
                        raise ValueError(
                            f"span ({start_idx}, {end_idx}) of example {i} is out of "
                            f"range for sequence_length={sequence_length}"
                        )
                    span_labels[i, start_idx, end_idx] = 1.
            return span_labels