def run_prediction()

in src/biencoder_predict_qa.py [0:0]


def _split_passage(p_bpe, c_max, stride):
    """Split a passage into overlapping windows of at most ``c_max`` tokens.

    A passage no longer than ``c_max`` is returned as a single window.
    Otherwise windows start every ``stride`` tokens, and the last one is
    clipped at the end of the passage.
    """
    if len(p_bpe) <= c_max:
        return [p_bpe]
    return [p_bpe[j:min(len(p_bpe), j + c_max)]
            for j in range(0, len(p_bpe) - c_max + stride, stride)]


def _stitch_logits(logits, batch_idxs, c_max, stride, full_len):
    """Rebuild per-token logits for one passage from its window rows.

    ``logits`` is indexed as ``[batch_row, position]``; ``batch_idxs`` lists
    the rows holding this passage's windows, in passage order. The result is
    truncated to ``full_len`` so padded positions never reach the softmax.
    This relies on padding being appended at the end of each row, which the
    original multi-window truncation already assumed.
    """
    if len(batch_idxs) == 1:
        # Single window: drop the trailing padding added by batching.
        return logits[batch_idxs[0], :full_len]

    # Each window contributes the `stride`-wide band around its centre; the
    # first and last windows additionally cover the passage's two ends.
    # NOTE: this arithmetic assumes stride <= c_max (lo would go negative
    # otherwise).
    lo = c_max // 2 - stride // 2
    hi = c_max // 2 + stride // 2
    last_pos = len(batch_idxs) - 1
    pieces = []
    for pos, j in enumerate(batch_idxs):
        if pos == 0:  # Case 1: first window keeps its left edge
            pieces.append(logits[j, :hi])
        elif pos == last_pos:  # Case 2: last window keeps its right edge
            pieces.append(logits[j, lo:])
        else:  # Case 3: middle window keeps only the central band
            pieces.append(logits[j, lo:hi])
    # The last window may carry padding; cut back to the real passage length.
    return torch.cat(pieces)[:full_len]


def run_prediction(roberta, heads, examples, args):
    """Predict an answer span for every example and write them to a JSON file.

    Args:
        roberta: object exposing ``model`` (called with ``features_only=True``),
            ``device`` and ``decode``.
        heads: mapping with ``'start'`` and ``'end'`` callables that turn the
            question's first-position feature vector into start/end vectors.
        examples: sequence of ``(qid, p_bpe, q_bpe)`` triples, where ``p_bpe``
            and ``q_bpe`` are the passage and question token sequences.
        args: namespace providing ``context_max_len``, ``context_stride``,
            ``batch_size``, ``max_tokens`` and ``output_file``.

    Passages longer than ``context_max_len`` are split into overlapping
    windows, scored independently, and stitched back together before the
    best span is chosen. The resulting ``{qid: answer_text}`` mapping is
    dumped to ``args.output_file``.
    """
    C_MAX = args.context_max_len
    STRIDE = args.context_stride
    # Even values are required because the stitching uses C_MAX//2, STRIDE//2.
    assert C_MAX % 2 == 0
    assert STRIDE % 2 == 0
    # Run prediction in batch
    pred_json = {}
    i = 0
    with tqdm(total=len(examples)) as pbar:
        while i < len(examples):
            # Greedily grab the next batch until batch_size and max_tokens are full
            cur_batch = []
            cur_tokens = 0
            qid_to_batch_index = collections.defaultdict(list)
            qid_to_full_bpe = {}
            while i < len(examples) and len(cur_batch) < args.batch_size:
                qid, p_bpe, q_bpe = examples[i]
                p_bpe_list = _split_passage(p_bpe, C_MAX, STRIDE)
                # Widest row the batch would have if this example is admitted.
                needed_tokens = max(cur_tokens, min(len(p_bpe), C_MAX), len(q_bpe))
                over_budget = (
                    (len(cur_batch) + len(p_bpe_list)) * needed_tokens > args.max_tokens
                )
                if over_budget and cur_batch:
                    # Wait until next batch to process this
                    break
                # An example that exceeds the budget on its own is still
                # admitted when the batch is empty; otherwise the loop would
                # hand an empty batch to the model and never advance.
                cur_tokens = needed_tokens
                for p_bpe_chunk in p_bpe_list:
                    qid_to_batch_index[qid].append(len(cur_batch))
                    cur_batch.append((qid, p_bpe_chunk, q_bpe))
                qid_to_full_bpe[qid] = p_bpe
                i += 1

            # Make prediction
            with torch.no_grad():
                p_tokens = collate_tokens([b[1] for b in cur_batch], pad_idx=1)
                q_tokens = collate_tokens([b[2] for b in cur_batch], pad_idx=1)
                p_features = roberta.model(p_tokens.to(device=roberta.device),
                                           features_only=True,
                                           return_all_hiddens=False)[0]  # B, L, d
                q_features = roberta.model(q_tokens.to(device=roberta.device),
                                           features_only=True,
                                           return_all_hiddens=False)[0]  # B, L, d
                q_cls = q_features[:, 0, :]  # B, d
                start_vec = heads['start'](q_cls)  # B, d
                end_vec = heads['end'](q_cls)  # B, d
                # Dot product of every passage position with the question vector.
                start_logits = torch.matmul(p_features, start_vec.unsqueeze(-1))[:, :, 0]
                end_logits = torch.matmul(p_features, end_vec.unsqueeze(-1))[:, :, 0]

                for qid, batch_idxs in qid_to_batch_index.items():
                    full_bpe = qid_to_full_bpe[qid]
                    cur_start_logits = _stitch_logits(
                        start_logits, batch_idxs, C_MAX, STRIDE, len(full_bpe))
                    cur_end_logits = _stitch_logits(
                        end_logits, batch_idxs, C_MAX, STRIDE, len(full_bpe))
                    start_probs = F.softmax(cur_start_logits, dim=0).tolist()
                    end_probs = F.softmax(cur_end_logits, dim=0).tolist()
                    best_start, best_end = find_span(start_probs, end_probs)
                    pred_str = roberta.decode(full_bpe[best_start:best_end + 1])
                    pred_json[qid] = pred_str
                    pbar.update(1)

    with open(args.output_file, 'w') as f:
        json.dump(pred_json, f)