in submix.py [0:0]
def query(self, P):
    """Answer one query by mixing the distributions in ``P``.

    Processing order:

    1. When ``self.temp < 1.0`` every ``p`` in ``P`` is replaced by
       ``p ** (1 / self.temp)`` renormalised to sum to one.
    2. ``P`` is passed through ``self.max_log_projection``.
    3. If ``self.queries_remaining <= 0`` or the smallest entry of
       ``self.eps_remaining`` is ``<= 0``, ``self.STOP`` is set to ``True``
       and the processed ``P[0]`` is returned; no counter is changed.
    4. Otherwise the mixes returned by
       ``self._get_pairwise_lambdas(self.pairing, P)`` are averaged, each
       ``self.eps_remaining[i]`` is reduced by the larger of two
       ``self.renyi_dp`` values computed for mix ``i``, and
       ``self.queries_remaining`` is decremented once.

    Args:
        P: Sequence of distributions to combine. Presumably one
            probability tensor per model -- confirm against the caller.

    Returns:
        The processed ``P[0]`` when stopped, otherwise the average of the
        mixes.
    """
    if self.temp < 1.0:
        # Temperature scaling: compute each un-normalised tensor once and
        # divide it by its own sum.
        sharpened = [torch.exp(torch.log(p) / self.temp) for p in P]
        P = [s / torch.sum(s) for s in sharpened]

    # NOTE(review): the nesting of this call was reconstructed from a
    # flattened listing; confirm it is meant to run for every temperature
    # and not only when self.temp < 1.0.
    P = self.max_log_projection(P)

    if self.queries_remaining <= 0 or min(self.eps_remaining) <= 0:
        # Query or epsilon allowance is used up: flag it and return the
        # first distribution without mixing.
        self.STOP = True
        return P[0]

    # Only the mixes are needed here; the lambdas are discarded.
    _, mixes = self._get_pairwise_lambdas(self.pairing, P)
    P_out = sum(mixes) / len(mixes)

    # NOTE(review): l is len(P_out), not len(mixes). A leave-one-out mean
    # over the mixes would use len(mixes) -- confirm which is intended.
    l = len(P_out)
    for i, p in enumerate(mixes):
        p_prime = (P_out - p / l) * (l / (l - 1))
        eps = max(
            self.renyi_dp(p_prime, P_out, alpha=self.alpha),
            self.renyi_dp(P_out, p, alpha=self.alpha),
        )
        self.eps_remaining[i] -= eps

    self.queries_remaining -= 1
    return P_out