def compute_logits_at_context()

in submix.py [0:0]


    def compute_logits_at_context(self, context):
        # Accept either a raw string or an already-tokenized tensor of token ids.
        if isinstance(context, str):
            context = torch.tensor(self.tokenizer.encode(context)).to(self.device)
        # Per-model logits at every position of the context ...
        logit = lambda model, x: model(x).logits.squeeze()
        L = [logit(lm, context) for lm in self.LMs]
        # ... and the corresponding probability distributions over the vocabulary.
        P = [nn.functional.softmax(logits, dim=-1) for logits in L]
        return L, P
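
Below is a minimal, self-contained usage sketch of what this method computes and returns. It assumes `self.LMs` is a list of causal language models sharing one tokenizer and `self.device` is the target device; the model names (`gpt2`, `distilgpt2`) and the prompt are illustrative stand-ins, not part of submix.py.

    import torch
    import torch.nn as nn
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cpu"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    LMs = [AutoModelForCausalLM.from_pretrained(n).to(device).eval()
           for n in ("gpt2", "distilgpt2")]

    # Mirror of compute_logits_at_context on a string context.
    context = torch.tensor(tokenizer.encode("The capital of France is")).to(device)
    with torch.no_grad():
        # Per-model logits at every context position ...
        L = [lm(context).logits.squeeze() for lm in LMs]
        # ... and the matching probability distributions over the shared vocabulary.
        P = [nn.functional.softmax(logits, dim=-1) for logits in L]

    for logits, probs in zip(L, P):
        # Each entry has shape (sequence_length, vocab_size) for its model.
        print(logits.shape, probs.shape)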