in src/rime/models/zero_shot/bayes_lm.py [0:0]
def _compute_log_p_x_given_y(self, Y, x, device):
""" evaluate p_x_given_y for each y in Y; label = x.expand(...) """
    batch_size = len(Y)
    sequences = [self.prompt.format(y=y, x=x) for y in Y]
    # Right padding is required so that attention_mask.sum(1) recovers each
    # sequence's true length n and the target tokens of x sit at its end.
    inputs = self.tokenizer(sequences, padding=True, return_tensors='pt').to(device)
    targets = self.tokenizer(x, return_tensors='pt')['input_ids'].to(device)
    targets = torch.vstack([targets for _ in range(batch_size)])  # repeat x's ids for every y
    seq_len = inputs['attention_mask'].sum(1).tolist()
    target_len = targets.shape[1]
    if hasattr(self.model, "transformer"):  # GPT-style causal LM
        hidden_states = self.model.transformer(**inputs)[0]
        # Position i predicts token i + 1, so the states at [n - target_len - 1, n - 1)
        # score the trailing target tokens of x.
        hidden_states = torch.vstack([h[n - target_len - 1: n - 1]
                                      for h, n in zip(hidden_states, seq_len)])
        logits = self.model.lm_head(hidden_states)
    elif hasattr(self.model, "bert"):  # BERT: [CLS] sequence [SEP]; performs similarly
        targets = targets[:, 1:-1]  # drop [CLS] and [SEP] from the target ids
        target_len = target_len - 2
        hidden_states = self.model.bert(**inputs)[0]
        # x's tokens occupy positions [n - target_len - 1, n - 1), just before the
        # final [SEP]; their own hidden states feed the MLM head.
        hidden_states = torch.vstack([h[n - target_len - 1: n - 1]
                                      for h, n in zip(hidden_states, seq_len)])
        logits = self.model.cls(hidden_states)
    else:
        # Generic fallback: project every position through the LM head and keep only
        # the target slice; decoding the non-target positions adds roughly 20% compute.
        logits = self.model(**inputs).logits
        logits = torch.vstack([lg[n - target_len - 1: n - 1]
                               for lg, n in zip(logits, seq_len)])
    # self.loss must yield per-token values (e.g. CrossEntropyLoss(reduction='none'))
    # so that the negated loss can be reshaped back to (batch_size, target_len).
    loss = self.loss(logits, targets.reshape(-1))
    return (-loss).reshape(targets.shape).mean(1).tolist()
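
# Usage sketch (illustration only; nothing below comes from the source file except
# the method signature above). The returned list holds one mean target-token
# log-likelihood per candidate y, so a zero-shot posterior over the labels can be
# formed with Bayes' rule, p(y | x) ∝ p(x | y) p(y). The helper name, the
# `bayes_lm` argument, and the optional prior are assumptions made for the example.
import math

import torch


def _posterior_over_labels(bayes_lm, Y, x, device, log_prior=None):
    """Combine the log p(x | y) scores with a log prior over Y (uniform by default)."""
    log_px_given_y = torch.tensor(bayes_lm._compute_log_p_x_given_y(Y, x, device))
    if log_prior is None:
        log_prior = torch.full((len(Y),), math.log(1.0 / len(Y)))  # uniform p(y)
    return torch.softmax(log_px_given_y + log_prior, dim=0)  # p(y | x), normalized over Y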