in submix.py [0:0]
def mix(self, p, p_prime, lamb=0.5):
    """Return the normalized convex mixture of two distributions.

    Computes ``lamb * p + (1 - lamb) * p_prime``, adds the constant 1e-20
    to every entry, and divides by the grand total so the entries sum to one.

    Args:
        p: First distribution tensor. Presumably non-negative; TODO confirm
            against callers.
        p_prime: Second distribution tensor, combined elementwise with ``p``.
        lamb: Weight applied to ``p``; ``p_prime`` receives ``1 - lamb``.

    Returns:
        The normalized mixture, a tensor of the broadcast shape of the inputs.
        The normalization sums over all elements, not per row.

    Raises:
        AssertionError: if the normalized total differs from 1 by 1e-5 or
            more (for example, when the inputs contain NaN). Like any
            ``assert``, this check is skipped under ``python -O``.
    """
    weighted = lamb * p + (1 - lamb) * p_prime
    # Small additive offset, presumably to avoid exact zeros downstream
    # (e.g. before a log). NOTE(review): 1e-20 underflows to 0 in float16.
    smoothed = weighted + 1e-20
    pmf = smoothed / smoothed.sum()
    # (deviation)**2 < 1e-10 is equivalent to |deviation| < 1e-5.
    deviation = pmf.sum().item() - 1.0
    assert deviation ** 2 < 1e-10, 'this is not a pmf'
    return pmf