def max_log_projection()

in submix.py [0:0]


    def max_log_projection(self, P, tol=1e-4):
        """Project every distribution in ``P`` into the max-log ball around ``P[0]``.

        For each ``P[i]`` with ``i >= 1``, a copy is repeatedly clamped
        elementwise into ``[exp(-gamma/2) * P[0], exp(gamma/2) * P[0]]`` and
        renormalized to sum to 1. This repeats until its max-log distance
        ``max_j |log q_j - log P[0]_j|`` is at most ``gamma/2 + tol``.
        Coordinates where both tensors are exactly 0 contribute distance 0,
        since their log-ratio would otherwise be ``0/0 = NaN``.

        Args:
            P: Indexable sequence of probability tensors (a list, or a tensor
                indexed along dim 0). ``P[0]`` is the centre of the ball.
                NOTE(review): presumably ``P[0]`` sums to 1; if it does not,
                the clamp/renormalize loop may fail to terminate — confirm.
            tol: Slack added to the ``gamma/2`` radius in the stopping test.
                Defaults to the previously hard-coded 1e-4.

        Returns:
            list: ``[P[0], Q_1, ..., Q_{n-1}]``. ``P[0]`` is returned as-is
            (not copied); each ``Q_i`` is a new tensor and the inputs are not
            modified.

        Raises:
            IndexError: if ``P`` is empty.
        """
        p0 = P[0]
        half = self.gamma / 2
        # The clamping box depends only on P[0] and gamma, so build it once.
        upper = np.exp(half) * p0
        lower = np.exp(-half) * p0
        p0_zero = p0 == 0

        def max_log_dist(p):
            gap = torch.abs(torch.log(p) - torch.log(p0))
            # Where both p and p0 are 0 the gap is NaN; torch.max would
            # propagate it and make the stopping test silently False.
            return gap.masked_fill((p == 0) & p0_zero, 0.0).max()

        Q = [p0]
        for p in P[1:]:
            p_prime = torch.clone(p)
            while max_log_dist(p_prime) > half + tol:
                # Clamp into the box around P[0], then renormalize. The
                # renormalization can push entries back out, hence the loop.
                p_prime = torch.minimum(p_prime, upper)
                p_prime = torch.maximum(p_prime, lower)
                p_prime = p_prime / torch.sum(p_prime)
            Q.append(p_prime)
        return Q