def mix()

in submix.py [0:0]


    def mix(self, p, p_prime, lamb=0.5):
        """Return the normalized mixture ``lamb * p + (1 - lamb) * p_prime``.

        A tiny constant (1e-20) is added to every entry before normalizing, so
        entries where both inputs are zero stay strictly positive. The result is
        then divided by its total sum over *all* elements (not per row).

        Args:
            p: Tensor holding the first distribution (presumably a pmf).
            p_prime: Tensor holding the second distribution; must broadcast
                element-wise with ``p``.
            lamb: Weight given to ``p``; ``1 - lamb`` goes to ``p_prime``.
                Values outside [0, 1] can produce negative entries, which the
                validity check below rejects.

        Returns:
            Tensor with non-negative entries summing to 1 over all elements.

        Raises:
            AssertionError: If the result is not a valid pmf (an entry is
                negative or NaN, or the total differs from 1 beyond
                tolerance). The check is an ``assert``, so it is skipped when
                Python runs with ``-O``.
        """
        mixed = lamb * p + (1 - lamb) * p_prime + 1e-20
        mixed = mixed / torch.sum(mixed)
        # Dividing by the sum makes the total ~1 even for invalid input, so the
        # sum test alone only catches NaN/inf; also require non-negative
        # entries (NaN fails ``>= 0`` as well).
        assert (
            bool(torch.all(mixed >= 0))
            and (torch.sum(mixed).item() - 1.0) ** 2 < 1e-10
        ), 'this is not a pmf'
        return mixed