in src/run_sentiment.py
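# Imports inferred from the calls below -- a sketch, since the file's real
# import block is not shown in this excerpt:
import json
import os
import random

import torch
from tqdm import tqdm
from fairseq.models.roberta import RobertaModel
from fairseq.data.data_utils import collate_tokens

# Assumed to be defined elsewhere in this file/repo: MockClassificationHead,
# read_data, make_batches, find_span, evaluate, QA_PROMPTS, MLM_PROMPTS,
# CALIBRATION_EXAMPLES, PAD_TOKEN.
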
def main(args):
    random.seed(0)
    # Load model, including start/end heads
    roberta = RobertaModel.from_pretrained(args.load_dir, checkpoint_file='model.pt')
    roberta.to('cuda')
    roberta.eval()
    if args.method == 'qa':
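        # Each head maps a prompt's <s> (CLS) feature to a "query" vector that
        # is later dotted against every token representation of an input,
        # QA-style (see the scoring loop below).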
        heads = {}
        for name in ['start', 'end']:
            head = MockClassificationHead.load_from_file(
                os.path.join(args.load_dir, f'model_qa_head_{name}.pt'),
                do_token_level=True)
            head.to('cuda')
            heads[name] = head
    print('Finished loading model.')

    # Read data
    data = read_data(args.dataset, roberta, num_examples=args.num_examples)
    batch_data = make_batches(data, args.batch_size)
    print(f'Loaded {len(data)} examples ({len(batch_data)} batches).')
    with torch.no_grad():
        # Precompute prompt representations
        calib_scores = []
        if args.method == 'qa':
            prompts = QA_PROMPTS[args.dataset][args.prompt_index]
            prompt_vecs = []
            for q, y in prompts:
                x = roberta.encode(q).unsqueeze(0)  # 1, L
                feats = roberta.model(x.to(roberta.device),
                                      features_only=True,
                                      return_all_hiddens=False)[0]  # 1, L, d
                cls_feats = feats[0, 0, :]  # d
                start_vec = heads['start'](cls_feats)  # d
                end_vec = heads['end'](cls_feats)  # d
                prompt_vec = torch.stack([start_vec, end_vec], dim=1)  # d, 2
                prompt_vecs.append(prompt_vec)
                # Contextual calibration
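                # (cf. Zhao et al., 2021): score content-free calibration inputs
                # with this prompt and subtract the mean from test scores below,
                # offsetting the model's prior bias toward either label.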
                cur_calib_scores = []
                for x_calib in CALIBRATION_EXAMPLES:
                    toks_calib = roberta.encode(x_calib)
                    feats, _ = roberta.model(toks_calib.to(roberta.device).unsqueeze(0),
                                             features_only=True,
                                             return_all_hiddens=False)  # 1, L, d
                    word_scores = torch.matmul(feats[0, :, :], prompt_vec)  # L, 2
                    cur_calib_scores.append(torch.max(word_scores, dim=0)[0].sum().item())
                calib_scores.append(sum(cur_calib_scores) / len(cur_calib_scores))
            print('calibration:', calib_scores)
        else:
            prompt_start, prompt_end, prompt_options = MLM_PROMPTS[args.dataset]
            prompt_toks = roberta.task.source_dictionary.encode_line(
                roberta.bpe.encode(prompt_start) + ' <mask> ' + roberta.bpe.encode(prompt_end))
            prompt_option_indices = [roberta.encode(x)[1] for x in prompt_options]
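            # encode() yields <s> option </s>, so index 1 above is the option's
            # vocab id; this assumes each verbalizer option is a single BPE token.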
            # Contextual calibration
            if args.calibrate_lmbff:
                cur_calib_scores = [[], []]
                for x_calib in CALIBRATION_EXAMPLES:
                    toks_calib = torch.cat([roberta.encode(x_calib)[:-1], prompt_toks])
                    feats, _ = roberta.model(toks_calib.to(roberta.device).unsqueeze(0))
                    mask_idx = (toks_calib == roberta.task.mask_idx).nonzero(as_tuple=False)
                    logits = feats[0, mask_idx, :].squeeze()
                    for y, prompt_option_idx in enumerate(prompt_option_indices):
                        cur_calib_scores[y].append(logits[prompt_option_idx].item())
                calib_scores = [sum(x) / len(x) for x in cur_calib_scores]
            else:
                calib_scores = [0, 0]
            print(f'prompt_toks={prompt_toks}, options={prompt_option_indices}, calibration={calib_scores}')
        print('Preprocessed prompts.')
        # Score predictions
        gold_labels = []
        pred_labels = []
        pred_scores = []
        for batch in tqdm(batch_data):
            cur_pred_scores = [{} for _ in batch]
            explanations = [{} for _ in batch]
            if args.method == 'qa':
                x_batch = collate_tokens([x for (x, y) in batch], pad_idx=PAD_TOKEN)
                feats = roberta.model(x_batch.to(roberta.device),
                                      features_only=True,
                                      return_all_hiddens=False)[0]  # B, L, d
                for (q, y), prompt_vec, calib_score in zip(prompts, prompt_vecs, calib_scores):
                    prompt_mat = prompt_vec.unsqueeze(0).expand(len(batch), -1, -1)  # B, d, 2
                    word_scores = torch.matmul(feats, prompt_mat)  # B, L, 2
                    # Max across words, then sum across start + end vectors
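                    # (the best start and best end are chosen independently, so
                    # this is the score of the best untied start/end pair)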
                    agg_scores = torch.max(word_scores, dim=1)[0].sum(dim=1).tolist()  # B
                    for i in range(len(batch)):
                        cur_pred_scores[i][y] = agg_scores[i] - calib_score
                        if args.verbose:
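                            # find_span (defined elsewhere in the repo) presumably
                            # returns the (start, end) pair maximizing span
                            # probability under the length cap, SQuAD-style.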
                            start_idx, end_idx = find_span(word_scores[i, :, 0].softmax(dim=0),
                                                           word_scores[i, :, 1].softmax(dim=0),
                                                           max_ans_len=5)  # Find a short rationale
                            try:
                                explanations[i][y] = roberta.decode(batch[i][0][start_idx:end_idx + 1])
                            except IndexError:
                                explanations[i][y] = ''
            else:  # args.method == 'mlm'
                xs_with_prompt = [torch.cat([x[:-1], prompt_toks]) for (x, y) in batch]  # Strip the old EOS
                x_batch = collate_tokens(xs_with_prompt, pad_idx=PAD_TOKEN)
                feats, _ = roberta.model(x_batch.to(roberta.device))  # B, L, V
                for i, x_with_prompt in enumerate(xs_with_prompt):
                    mask_idx = (x_with_prompt == roberta.task.mask_idx).nonzero(as_tuple=False)
                    logits = feats[i, mask_idx, :].squeeze()
                    for y, prompt_option_idx in enumerate(prompt_option_indices):
                        cur_pred_scores[i][y] = logits[prompt_option_idx].item() - calib_scores[y]
            for i, (x, y) in enumerate(batch):
                gold_labels.append(y)
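                # Score margin for label 1 over label 0 (assumes binary labels);
                # consumed by evaluate() below, presumably for a threshold-free metric.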
                pred_scores.append(cur_pred_scores[i][1] - cur_pred_scores[i][0])
                y_pred, max_score = max(cur_pred_scores[i].items(), key=lambda p: p[1])
                pred_labels.append(y_pred)
                if args.verbose:
                    log_obj = {
                        'x': roberta.decode(x),
                        'y': y,
                        'pred': y_pred,
                        'scores': cur_pred_scores[i],
                        # Rationale for the gold label (populated in QA verbose mode only)
                        'explanation': explanations[i][y] if explanations[i] else '',
                    }
                    print(json.dumps(log_obj))
    # Print stats
    num_correct = sum(1 for y, pred in zip(gold_labels, pred_labels) if y == pred)
    print(f'Accuracy: {num_correct}/{len(gold_labels)} = {100 * num_correct / len(gold_labels):.2f}%')
    fp = sum(1 for y, pred in zip(gold_labels, pred_labels) if y == 0 and pred == 1)
    fn = sum(1 for y, pred in zip(gold_labels, pred_labels) if y == 1 and pred == 0)
    print(f'fp={fp}, fn={fn}')
    print(evaluate(gold_labels, pred_scores))
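
# A minimal sketch of the CLI entry point this function implies; the argument
# names mirror every `args.*` attribute read above, but types and defaults are
# assumptions, not the script's confirmed interface.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_dir', help='Directory containing model.pt (and QA head files)')
    parser.add_argument('--method', choices=['qa', 'mlm'], default='mlm')
    parser.add_argument('--dataset')
    parser.add_argument('--prompt_index', type=int, default=0)
    parser.add_argument('--num_examples', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--calibrate_lmbff', action='store_true')
    parser.add_argument('--verbose', action='store_true')
    main(parser.parse_args())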