def main()

in src/run_sentiment.py

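For context, the function below relies on imports and repo-local helpers defined elsewhere in src/run_sentiment.py. A minimal sketch of a compatible module header is given here; the standard-library, torch/tqdm, and fairseq imports are inferred from the calls in the code, while the module paths of the repo-local names are not shown in this excerpt and are listed in a comment rather than guessed.

import json
import os
import random

import torch
from tqdm import tqdm

from fairseq.data.data_utils import collate_tokens
from fairseq.models.roberta import RobertaModel

# Repo-local names used below, defined elsewhere in this repository:
# MockClassificationHead, read_data, make_batches, find_span, evaluate,
# QA_PROMPTS, MLM_PROMPTS, CALIBRATION_EXAMPLES, PAD_TOKEN.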

def main(args):
    random.seed(0)

    # Load model (plus the start/end QA heads when args.method == 'qa')
    roberta = RobertaModel.from_pretrained(args.load_dir, checkpoint_file='model.pt')
    roberta.to('cuda')
    roberta.eval()
    if args.method == 'qa':
        heads = {}
        for name in ['start', 'end']:
            head = MockClassificationHead.load_from_file(
                    os.path.join(args.load_dir, f'model_qa_head_{name}.pt'),
                    do_token_level=True)
            head.to('cuda')
            heads[name] = head
    print('Finished loading model.')

    # Read data
    data = read_data(args.dataset, roberta, num_examples=args.num_examples)
    batch_data = make_batches(data, args.batch_size)
    print(f'Loaded {len(data)} examples ({len(batch_data)} batches).')

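    # The block below supports two zero-shot scoring modes:
    # - 'qa': each label's question prompt is encoded, its <s> (CLS) features are
    #   projected through the start/end heads, and the label score is the
    #   per-head maximum dot product over the input's tokens, summed across the
    #   start and end heads.
    # - otherwise (MLM): a cloze prompt containing <mask> is appended to the
    #   input and the masked-LM logits of the label words are compared.
    # The QA mode always subtracts a per-label score computed on
    # CALIBRATION_EXAMPLES; the MLM mode does so only if args.calibrate_lmbff.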
    with torch.no_grad():
        # Precompute prompt representations and contextual calibration scores
        calib_scores = []
        if args.method == 'qa':
            prompts = QA_PROMPTS[args.dataset][args.prompt_index]
            prompt_vecs = []
            for q, y in prompts:
                x = roberta.encode(q).unsqueeze(0)  # 1, L
                feats = roberta.model(x.to(roberta.device),
                                      features_only=True,
                                      return_all_hiddens=False)[0]  # 1, L, d
                cls_feats = feats[0,0,:]  # d
                start_vec = heads['start'](cls_feats)  # d
                end_vec = heads['end'](cls_feats)  # d
                prompt_vec = torch.stack([start_vec, end_vec], dim=1)  # d, 2
                prompt_vecs.append(prompt_vec)

                # Contextual calibration
                cur_calib_scores = []
                for x_calib in CALIBRATION_EXAMPLES: 
                    toks_calib = roberta.encode(x_calib)
                    feats, _ = roberta.model(toks_calib.to(roberta.device).unsqueeze(0),
                                             features_only=True,
                                             return_all_hiddens=False)  # 1, L, d
                    word_scores = torch.matmul(feats[0,:,:], prompt_vec)  # L, 2
                    cur_calib_scores.append(torch.max(word_scores, dim=0)[0].sum().item())
                calib_scores.append(sum(cur_calib_scores) / len(cur_calib_scores))
            print('calibration:', calib_scores)

        else:
            prompt_start, prompt_end, prompt_options = MLM_PROMPTS[args.dataset]
            prompt_toks = roberta.task.source_dictionary.encode_line(
                    roberta.bpe.encode(prompt_start) + ' <mask> ' + roberta.bpe.encode(prompt_end))
            prompt_option_indices = [roberta.encode(x)[1] for x in prompt_options]  # Index 1 skips the leading <s>

            # Contextual calibration
            if args.calibrate_lmbff:
                cur_calib_scores = [[], []]
                for x_calib in CALIBRATION_EXAMPLES:
                    toks_calib = torch.cat([roberta.encode(x_calib)[:-1], prompt_toks])  # Strip the old EOS
                    feats, _ = roberta.model(toks_calib.to(roberta.device).unsqueeze(0))  # 1, L, V
                    mask_idx = (toks_calib == roberta.task.mask_idx).nonzero(as_tuple=False)
                    logits = feats[0,mask_idx,:].squeeze()
                    for y, prompt_option_idx in enumerate(prompt_option_indices):
                        cur_calib_scores[y].append(logits[prompt_option_idx].item())
                calib_scores = [sum(x) / len(x) for x in cur_calib_scores]
            else:
                calib_scores = [0, 0]

            print(f'prompt_toks={prompt_toks}, options={prompt_option_indices}, calibration={calib_scores}')

        print('Preprocessed prompts.')

        # Score predictions
        gold_labels = []
        pred_labels = []
        pred_scores = []
        for batch in tqdm(batch_data):
            cur_pred_scores = [{} for _ in batch]
            explanations = [{} for _ in batch]
            if args.method == 'qa':
                x_batch = collate_tokens([x for (x, y) in batch], pad_idx=PAD_TOKEN)
                feats = roberta.model(x_batch.to(roberta.device),
                                      features_only=True,
                                      return_all_hiddens=False)[0]  # B, L, d
                for (q, y), prompt_vec, calib_score in zip(prompts, prompt_vecs, calib_scores):
                    prompt_mat = prompt_vec.unsqueeze(0).expand(len(batch), -1, -1)  # B, d, 2
                    word_scores = torch.matmul(feats, prompt_mat)  # B, L, 2
                    # Max across words, then sum across start + end vectors
                    agg_scores = torch.max(word_scores, dim=1)[0].sum(dim=1).tolist()  # B
                    for i in range(len(batch)):
                        cur_pred_scores[i][y] = agg_scores[i] - calib_score
                        if args.verbose:
                            start_idx, end_idx = find_span(word_scores[i,:,0].softmax(dim=0),
                                                           word_scores[i,:,1].softmax(dim=0),
                                                           max_ans_len=5)  # Find a short rationale
                            try:
                                explanations[i][y] = roberta.decode(batch[i][0][start_idx:end_idx+1])
                            except IndexError:
                                explanations[i][y] = ''
            else:  # args.method == 'mlm'
                xs_with_prompt = [torch.cat([x[:-1], prompt_toks]) for (x, y) in batch]  # Strip the old EOS
                x_batch = collate_tokens(xs_with_prompt, pad_idx=PAD_TOKEN)
                feats, _ = roberta.model(x_batch.to(roberta.device))  # B, L, V
                for i, x_with_prompt in enumerate(xs_with_prompt):
                    mask_idx = (x_with_prompt == roberta.task.mask_idx).nonzero(as_tuple=False)
                    logits = feats[i,mask_idx,:].squeeze()
                    for y, prompt_option_idx in enumerate(prompt_option_indices):
                        cur_pred_scores[i][y] = logits[prompt_option_idx].item() - calib_scores[y]
            for i, (x, y) in enumerate(batch):
                gold_labels.append(y)
                pred_scores.append(cur_pred_scores[i][1] - cur_pred_scores[i][0])  # Binary margin: score(1) - score(0)
                y_pred, max_score = max(cur_pred_scores[i].items(), key=lambda p: p[1])
                pred_labels.append(y_pred)
                if args.verbose:
                    log_obj = {
                            'x': roberta.decode(x),
                            'y': y,
                            'pred': y_pred,
                            'scores': cur_pred_scores[i],
                            'explanation': explanations[i][y] if explanations[i] else ''
                    }
                    print(json.dumps(log_obj))

    # Print stats
    num_correct = sum(1 for y, pred in zip(gold_labels, pred_labels) if y == pred)
    print(f'Accuracy: {num_correct}/{len(gold_labels)} = {100 * num_correct / len(gold_labels):.2f}%')
    fp = sum(1 for y, pred in zip(gold_labels, pred_labels) if y == 0 and pred == 1)
    fn = sum(1 for y, pred in zip(gold_labels, pred_labels) if y == 1 and pred == 0)
    print(f'fp={fp}, fn={fn}')
    print(evaluate(gold_labels, pred_scores))
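
The verbose branch above calls a repo-local find_span helper that is not included in this excerpt. Purely to illustrate the semantics implied by the call site (two probability vectors over token positions and a max_ans_len cap on span length), a hypothetical implementation could look like the following; the repository's actual helper may differ.

def find_span(start_probs, end_probs, max_ans_len=5):
    # Hypothetical sketch: return the (start, end) index pair maximizing
    # start_probs[start] * end_probs[end] with start <= end < start + max_ans_len.
    seq_len = start_probs.size(0)
    best_score, best = float('-inf'), (0, 0)
    for s in range(seq_len):
        for e in range(s, min(s + max_ans_len, seq_len)):
            score = (start_probs[s] * end_probs[e]).item()
            if score > best_score:
                best_score, best = score, (s, e)
    return best

main() reads load_dir, method, dataset, prompt_index, num_examples, batch_size, calibrate_lmbff, and verbose off args. A command-line entry point consistent with that usage is sketched below; the defaults, choices, and which flags are required are assumptions rather than the repository's actual CLI.

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--load_dir', required=True)
    parser.add_argument('--method', choices=['qa', 'mlm'], default='mlm')
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--prompt_index', type=int, default=0)
    parser.add_argument('--num_examples', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--calibrate_lmbff', action='store_true')
    parser.add_argument('--verbose', action='store_true')
    main(parser.parse_args())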