# Dassl.pytorch/dassl/modeling/ops/mmd.py
import torch
import torch.nn as nn
from torch.nn import functional as F
class MaximumMeanDiscrepancy(nn.Module):
    """Maximum Mean Discrepancy (MMD) between two batches of features.

    MMD^2(x, y) = E[k(x, x')] + E[k(y, y')] - 2 * E[k(x, y)]

    The self-similarity terms exclude the diagonal entries k(x_i, x_i),
    so each input batch must contain at least two samples.

    Args:
        kernel_type (str): one of "linear", "poly" or "rbf" (default "rbf").
        normalize (bool): if True, L2-normalize each feature vector before
            computing the MMD.
    """

    def __init__(self, kernel_type="rbf", normalize=False):
        super().__init__()
        self.kernel_type = kernel_type
        self.normalize = normalize

    def forward(self, x, y):
        """Compute MMD^2 between x and y, each of shape (batch, dim).

        Returns a scalar tensor. Note the diagonal-removed estimate can be
        slightly negative for small batches.
        """
        if self.normalize:
            x = F.normalize(x, dim=1)
            y = F.normalize(y, dim=1)
        if self.kernel_type == "linear":
            return self.linear_mmd(x, y)
        elif self.kernel_type == "poly":
            return self.poly_mmd(x, y)
        elif self.kernel_type == "rbf":
            return self.rbf_mmd(x, y)
        else:
            # Explicit message instead of a bare NotImplementedError.
            raise NotImplementedError(
                "Unknown kernel_type: {} "
                "(expected 'linear', 'poly' or 'rbf')".format(self.kernel_type)
            )

    def linear_mmd(self, x, y):
        """MMD^2 with the linear kernel k(x, y) = x^T y."""
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_xy = torch.mm(x, y.t())
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
        """MMD^2 with the polynomial kernel k(x, y) = (alpha * x^T y + c)^d."""
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_xx = (alpha*k_xx + c).pow(d)
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_yy = (alpha*k_yy + c).pow(d)
        k_xy = torch.mm(x, y.t())
        k_xy = (alpha*k_xy + c).pow(d)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def rbf_mmd(self, x, y):
        """MMD^2 with a mixture of RBF (Gaussian) kernels."""
        # k_xx: pairwise squared distances within x, diagonal removed.
        d_xx = self.euclidean_squared_distance(x, x)
        d_xx = self.remove_self_distance(d_xx)
        k_xx = self.rbf_kernel_mixture(d_xx)
        # k_yy: pairwise squared distances within y, diagonal removed.
        d_yy = self.euclidean_squared_distance(y, y)
        d_yy = self.remove_self_distance(d_yy)
        k_yy = self.rbf_kernel_mixture(d_yy)
        # k_xy: cross distances keep the full matrix.
        d_xy = self.euclidean_squared_distance(x, y)
        k_xy = self.rbf_kernel_mixture(d_xy)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    @staticmethod
    def rbf_kernel_mixture(exponent, sigmas=(1, 5, 10)):
        """Sum of RBF kernels exp(-d / (2*sigma^2)) over several bandwidths.

        `sigmas` is a tuple (not a list) to avoid a mutable default argument.
        """
        K = 0
        for sigma in sigmas:
            gamma = 1.0 / (2.0 * sigma**2)
            K += torch.exp(-gamma * exponent)
        return K

    @staticmethod
    def remove_self_distance(distmat):
        """Drop the diagonal of a square (n, n) matrix, returning (n, n-1).

        Vectorized replacement for the row-by-row Python loop: select the
        off-diagonal entries with a boolean mask and reshape.
        """
        n = distmat.size(0)
        mask = ~torch.eye(n, dtype=torch.bool, device=distmat.device)
        return distmat[mask].view(n, n - 1)

    @staticmethod
    def euclidean_squared_distance(x, y):
        """Pairwise squared Euclidean distances, shape (x.size(0), y.size(0))."""
        m, n = x.size(0), y.size(0)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a^T b, assembled in place.
        distmat = (
            torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +
            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        )
        distmat.addmm_(x, y.t(), beta=1, alpha=-2)
        # Guard against tiny negative values from floating-point round-off,
        # which would make exp(-gamma * d) exceed 1 in the RBF kernel.
        return distmat.clamp_(min=0)
if __name__ == "__main__":
    # Smoke test: RBF-kernel MMD between two random feature batches.
    criterion = MaximumMeanDiscrepancy(kernel_type="rbf")
    feat_a = torch.rand(3, 100)
    feat_b = torch.rand(3, 100)
    distance = criterion(feat_a, feat_b)
    print(distance.item())