in paq/generation/answer_extractor/span2D_model.py [0:0]
def _convert_to_span_matrix(start_positions, end_positions):
    """Build a binary span-label matrix from gold start/end positions.

    For batch row ``i``, the start and end indices are walked in lockstep
    (``zip(start_positions[i], end_positions[i])``). Each pair marks
    ``span_labels[i, start, end] = 1.``. A negative index is the null
    indicator (-1). The first pair in which either index is negative ends
    processing of that row, so any pairs after it are ignored.

    NOTE(review): ``span_logits`` and ``sequence_length`` are not parameters.
    They are read from the enclosing scope (presumably the surrounding
    model method — confirm against the caller).

    Args:
        start_positions: per-example iterables of gold start indices,
            padded with -1 (layout assumed: list of lists or a ``[B, N]``
            integer tensor — TODO confirm).
        end_positions: per-example iterables of gold end indices, laid out
            like ``start_positions``.

    Returns:
        A tensor created with ``torch.zeros_like(span_logits)`` (same shape,
        dtype and device; ``[B, L, L]``). It is zero everywhere except the
        labelled ``(start, end)`` cells, which are set to 1.

    Raises:
        IndexError: if a non-null start or end index is not smaller than
            ``sequence_length``.
    """
    span_labels = torch.zeros_like(span_logits)  # [B, L, L]
    for i, (row_starts, row_ends) in enumerate(
        zip(start_positions, end_positions)
    ):
        for start_idx, end_idx in zip(row_starts, row_ends):
            # We use -1 as the null (padding) indicator: stop at the first
            # pair where either index is negative.
            if start_idx < 0 or end_idx < 0:
                break
            # Explicit check instead of ``assert`` so the validation is not
            # stripped when Python runs with -O.
            if start_idx >= sequence_length or end_idx >= sequence_length:
                raise IndexError(
                    f"span label ({start_idx}, {end_idx}) in batch row {i} "
                    f"is out of range for sequence_length={sequence_length}"
                )
            span_labels[i, start_idx, end_idx] = 1.
    return span_labels