def discrepancy_slice_wasserstein()

in swd.py [0:0]


def discrepancy_slice_wasserstein(p1, p2, num_projections=128):
    """Return the sliced Wasserstein discrepancy between two batches.

    When the inputs have more than one feature column, both are multiplied
    by the same random matrix whose ``num_projections`` columns are scaled
    to unit L2 norm. Each result is then passed through ``sort_rows`` and
    the mean squared difference of the two sorted results is returned.

    Args:
        p1: 2-D tensor; its second dimension is the feature size and its
            first dimension is the row count handed to ``sort_rows``.
        p2: 2-D tensor compared against ``p1``. It is projected with the
            same matrix and sorted with ``p1``'s row count, so it is
            expected to have the same shape as ``p1``.
        num_projections: Number of random 1-D projections used when the
            inputs have more than one feature column. Defaults to 128.

    Returns:
        Scalar tensor holding the mean squared difference between the
        sorted (and, if applicable, projected) inputs.
    """
    shape = array_ops.shape(p1)
    feature_dim = p1.get_shape().as_list()[1]
    # An unknown static feature size (None) takes the projection path rather
    # than raising TypeError on the ``> 1`` comparison under Python 3.
    if feature_dim is None or feature_dim > 1:
        # For data more than one-dimensional, perform multiple random
        # projections to 1-D.
        proj = random_ops.random_normal([shape[1], num_projections])
        # Scale every column to unit L2 norm. The axis-0 sum has shape
        # [num_projections] and broadcasts across the rows of ``proj``, so
        # the keep_dims/keepdims keyword (renamed between TensorFlow
        # releases) is not needed.
        proj *= math_ops.rsqrt(math_ops.reduce_sum(math_ops.square(proj), 0))
        p1 = math_ops.matmul(p1, proj)
        p2 = math_ops.matmul(p2, proj)
    num_rows = shape[0]
    # NOTE(review): sort_rows is defined elsewhere; presumably it sorts each
    # column independently so the two batches can be compared rank-by-rank.
    p1 = sort_rows(p1, num_rows)
    p2 = sort_rows(p2, num_rows)
    # reduce_mean with no axis already yields a scalar; no second reduction
    # is required.
    return math_ops.reduce_mean(math_ops.square(p1 - p2))