def query()

in submix.py [0:0]


    def query(self, P):
        """Aggregate one query's distributions and charge the per-mixture privacy cost.

        Args:
            P: Iterable of probability tensors, one per underlying model. They are
                temperature-sharpened (when ``self.temp < 1.0``) and then passed
                through ``self.max_log_projection``. Their exact shape/ordering is
                whatever ``max_log_projection`` and ``_get_pairwise_lambdas``
                expect -- TODO confirm against callers.

        Returns:
            If ``self.queries_remaining`` or any entry of ``self.eps_remaining`` is
            exhausted (<= 0): sets ``self.STOP = True`` and returns ``P[0]`` (after
            sharpening and projection). Otherwise returns the uniform average of the
            mixtures produced by ``self._get_pairwise_lambdas``, and for each mixture
            ``i`` subtracts the larger of the two ``self.renyi_dp`` values computed
            below from ``self.eps_remaining[i]``. In both cases
            ``self.queries_remaining`` is decremented by one.

        Raises:
            ValueError: if ``_get_pairwise_lambdas`` yields fewer than two mixtures,
                because the leave-one-out average used for accounting is undefined.
        """
        if self.temp < 1.0:
            # Sharpen each distribution: q is proportional to p ** (1 / temp).
            # Work in log space and subtract the max before exponentiating; this is
            # mathematically identical to normalising exp(log(p) / temp) directly,
            # but does not underflow every entry to 0 (-> 0/0 = NaN) at small temps.
            sharpened = []
            for p in P:
                scaled = torch.log(p) / self.temp
                q = torch.exp(scaled - scaled.max())
                sharpened.append(q / q.sum())
            P = sharpened
        P = self.max_log_projection(P)
        if self.queries_remaining <= 0 or min(self.eps_remaining) <= 0:
            # Budget exhausted: signal stop and fall back to the first distribution.
            self.STOP = True
            P_out = P[0]
        else:
            _, mixes = self._get_pairwise_lambdas(self.pairing, P)
            n = len(mixes)
            if n < 2:
                raise ValueError(
                    "query() needs at least two mixtures for leave-one-out "
                    f"privacy accounting, got {n}"
                )
            P_out = sum(mixes) / n
            for i, p in enumerate(mixes):
                # Average of the other n - 1 mixtures (output with mixture i removed):
                # (n * P_out - p) / (n - 1) == (P_out - p / n) * n / (n - 1).
                # n must be the number of averaged mixtures, not len(P_out).
                p_prime = (P_out - p / n) * (n / (n - 1))
                eps = self.renyi_dp(p_prime, P_out, alpha=self.alpha)
                # NOTE(review): the reverse direction compares P_out with p (mixture i
                # alone) rather than with p_prime; a symmetric leave-one-out bound
                # would use renyi_dp(P_out, p_prime). Kept as-is -- confirm intent.
                eps = max(eps, self.renyi_dp(P_out, p, alpha=self.alpha))
                self.eps_remaining[i] -= eps
        self.queries_remaining -= 1
        return P_out