def _compute_log_p_x_given_y()

in src/rime/models/zero_shot/bayes_lm.py [0:0]


    def _compute_log_p_x_given_y(self, Y, x, device):
        """ evaluate p_x_given_y for each y in Y; label = x.expand(...) """

        batch_size = len(Y)
        sequences = [self.prompt.format(y=y, x=x) for y in Y]  # one prompt per candidate y
        inputs = self.tokenizer(sequences, padding=True, return_tensors='pt').to(device)
        targets = self.tokenizer(x, return_tensors='pt')['input_ids'].to(device)
        targets = torch.vstack([targets for _ in range(batch_size)])  # (batch_size, target_len)

        seq_len = inputs['attention_mask'].sum(1).tolist()  # unpadded length of each prompt
        target_len = targets.shape[1]

        if hasattr(self.model, "transformer"):  # gpt causal lm
            hidden_states = self.model.transformer(**inputs)[0]
            # causal lm: the hidden state at position i predicts token i + 1, so take the
            # states one step before each target token
            hidden_states = torch.vstack([h[n - target_len - 1: n - 1]
                                          for h, n in zip(hidden_states, seq_len)])
            logits = self.model.lm_head(hidden_states)

        elif hasattr(self.model, "bert"):  # bert [CLS] sequence [SEP]; performs similarly to the causal lm
            targets = targets[:, 1:-1]  # drop the [CLS] / [SEP] ids from the tokenized target
            target_len = target_len - 2

            hidden_states = self.model.bert(**inputs)[0]
            # the mlm head scores the token at its own position and [SEP] sits at index n - 1,
            # so the slice covers the target span itself; e.g. "[CLS] x [SEP]" gives n = 3,
            # target_len = 1 and the slice [3-1-1 : 3-1] = [1:2], i.e. the position of x
            hidden_states = torch.vstack([h[n - target_len - 1: n - 1]
                                          for h, n in zip(hidden_states, seq_len)])
            logits = self.model.cls(hidden_states)

        else:  # generic fallback: running the lm head over every position, not just the
               # target span, can lead to 20% longer compute time
            logits = self.model(**inputs).logits
            logits = torch.vstack([lgt[n - target_len - 1: n - 1]
                                   for lgt, n in zip(logits, seq_len)])

        # self.loss must return per-token values (no reduction) so they can be reshaped per sequence
        loss = self.loss(logits, targets.reshape(-1))
        return (-loss).reshape(targets.shape).mean(1).tolist()  # mean log p(x | y) per candidate y
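
For orientation, here is a minimal usage sketch, assuming the method above is in scope as a plain function and that the enclosing object only needs the four attributes the method reads (prompt, tokenizer, model, loss). The GPT-2 checkpoint, the prompt template, the example inputs, and the SimpleNamespace stand-in for self are illustrative assumptions, not the repository's actual wiring.

    import types

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Stand-in for `self`: only the attributes read by _compute_log_p_x_given_y (assumed wiring).
    self_like = types.SimpleNamespace(
        prompt="{y}:{x}",                                    # assumed template; must end with x
        tokenizer=AutoTokenizer.from_pretrained("gpt2"),
        model=AutoModelForCausalLM.from_pretrained("gpt2").eval(),
        loss=torch.nn.CrossEntropyLoss(reduction='none'),    # per-token losses, no reduction
    )
    self_like.tokenizer.pad_token = self_like.tokenizer.eos_token  # gpt2 has no pad token

    with torch.no_grad():
        scores = _compute_log_p_x_given_y(
            self_like,
            Y=["a glowing movie review", "a harsh movie review"],
            # leading space so gpt2 BPE tokenizes x identically standalone and inside the prompt
            x=" I loved every minute of it.",
            device="cpu",
        )
    print(scores)  # one mean log p(x | y) per candidate; higher = more compatible y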