in swd.py [0:0]
def discrepancy_slice_wasserstein(p1, p2, num_projections=128):
    """Estimate the sliced Wasserstein discrepancy between two sample sets.

    When the feature dimension is not statically known to be 1, both inputs
    are projected onto ``num_projections`` random unit-norm directions. Each
    projected column is then sorted with ``sort_rows``, and the mean squared
    difference between the sorted values of ``p1`` and ``p2`` is returned.
    Fresh random directions are drawn on every call (no seed is passed).

    NOTE(review): assumes ``sort_rows(x, n)`` (defined elsewhere in this file)
    sorts each column of ``x`` over its first ``n`` rows in the same order for
    both inputs. Confirm against its definition.

    Args:
        p1: 2-D tensor, presumably ``(num_samples, dim)``. The matmul with a
            ``[dim, num_projections]`` matrix requires this layout.
        p2: Tensor assumed to have the same shape as ``p1``. Its rows are
            sorted using ``p1``'s row count, so it must have at least that
            many rows.
        num_projections: Number of random 1-D projection directions.
            Defaults to 128, the value previously hard-coded.

    Returns:
        Scalar tensor: mean squared difference of the sorted projections.
    """
    s = array_ops.shape(p1)
    static_dim = p1.get_shape().as_list()[1]
    # An unknown static dim is None, and `None > 1` raises TypeError on
    # Python 3. Projecting 1-D data is harmless: each unit-norm 1x1 direction
    # is +/-1, which leaves the sorted squared error unchanged. So project
    # whenever the dim is unknown or greater than one.
    if static_dim is None or static_dim > 1:
        # For data more than one-dimensional, perform multiple random
        # projections to 1-D.
        proj = random_ops.random_normal([s[1], num_projections])
        # Scale each column to unit L2 norm. `keepdims` replaces the
        # deprecated `keep_dims` argument, which TF 2 removed.
        proj *= math_ops.rsqrt(
            math_ops.reduce_sum(math_ops.square(proj), 0, keepdims=True))
        p1 = math_ops.matmul(p1, proj)
        p2 = math_ops.matmul(p2, proj)
    p1 = sort_rows(p1, s[0])
    p2 = sort_rows(p2, s[0])
    # reduce_mean with no axis already yields a scalar, so no second
    # reduction is needed.
    return math_ops.reduce_mean(math_ops.square(p1 - p2))