in submix.py [0:0]
def compute_logits_at_context(self, context):
    """Run every model in ``self.LMs`` on ``context`` and return logits and probabilities.

    Args:
        context: Either a ``str`` or an already-prepared model input.
            A string is encoded with ``self.tokenizer.encode``, wrapped in a
            ``torch.tensor`` and moved to ``self.device``. Any other value is
            passed to the models unchanged (presumably a tensor of token ids
            already on the right device -- confirm against callers).

    Returns:
        A tuple ``(L, P)`` of two lists, each with one entry per model in
        ``self.LMs`` and in the same order:
            * ``L[i]`` is ``model(context).logits.squeeze()``.
            * ``P[i]`` is the softmax of ``L[i]`` over its last dimension.
    """
    if isinstance(context, str):
        # Encode the string that was actually passed in. The previous code
        # referenced an undefined name here and failed for every str input.
        token_ids = self.tokenizer.encode(context)
        context = torch.tensor(token_ids).to(self.device)

    # squeeze() drops every size-1 dimension, so the rank of each entry
    # depends on the input (e.g. a single-token context can collapse to 1-D).
    L = [lm(context).logits.squeeze() for lm in self.LMs]

    # Normalise over the last axis: equivalent to dim=1 for 2-D logits, and
    # still valid when squeeze() has left a 1-D tensor.
    P = [nn.functional.softmax(logits, dim=-1) for logits in L]
    return L, P