in submix.py [0:0]
def max_log_projection(self, P, *, tol=1e-4, max_iter=None):
    """Project every tensor in ``P`` into the max-log ball around ``P[0]``.

    The max-log distance between tensors ``p`` and ``q`` is
    ``max(|log(p) - log(q)|)``. Each ``P[i]`` with ``i >= 1`` is cloned and then
    repeatedly clamped elementwise into
    ``[exp(-gamma/2) * P[0], exp(gamma/2) * P[0]]`` and rescaled to sum to 1,
    until its max-log distance to ``P[0]`` is at most ``self.gamma / 2 + tol``.

    Args:
        P: Indexable, sliceable collection of tensors (e.g. a list of tensors,
            or a tensor whose first dimension is iterated). ``P[0]`` is the
            centre of the ball.
            NOTE(review): presumably each entry is a probability vector with
            positive entries (the loop renormalises to sum 1) -- confirm with
            callers.
        tol: Slack added to the radius ``self.gamma / 2`` when deciding that a
            projected tensor is close enough. Defaults to ``1e-4``.
        max_iter: Optional cap on clamp/renormalise rounds per tensor. ``None``
            (the default) means no cap, matching the historical behaviour.

    Returns:
        A list whose first element is ``P[0]`` itself (not a copy), followed by
        projected copies of ``P[1:]`` in order. An empty ``P`` gives ``[]``.

    Raises:
        RuntimeError: If ``max_iter`` is set and a tensor is still outside the
            ball after that many rounds.
    """
    if len(P) == 0:
        return []

    center = P[0]
    radius = self.gamma / 2
    # Loop-invariant quantities: computed once rather than on every round.
    log_center = torch.log(center)
    upper = np.exp(radius) * center
    lower = np.exp(-radius) * center

    def _max_log_dist(p):
        # Largest absolute elementwise log-ratio between p and the centre.
        return torch.max(torch.abs(torch.log(p) - log_center))

    projected = [center]
    for idx, p in enumerate(P[1:], start=1):
        p_prime = torch.clone(p)
        rounds = 0
        while _max_log_dist(p_prime) > radius + tol:
            if max_iter is not None and rounds >= max_iter:
                raise RuntimeError(
                    f"max_log_projection: P[{idx}] still outside the max-log "
                    f"ball after {max_iter} rounds"
                )
            # Clamp into the box around the centre. Renormalising afterwards
            # can push entries back outside the box, hence the loop.
            p_prime = torch.minimum(p_prime, upper)
            p_prime = torch.maximum(p_prime, lower)
            p_prime = p_prime / torch.sum(p_prime)
            rounds += 1
        projected.append(p_prime)
    return projected