Dassl.pytorch/dassl/modeling/ops/optimal_transport.py (121 lines of code) (raw):

import torch
import torch.nn as nn
from torch.nn import functional as F


class OptimalTransport(nn.Module):
    """Base module providing pairwise cost matrices for transport losses."""

    @staticmethod
    def distance(batch1, batch2, dist_metric="cosine"):
        """Compute the pairwise cost matrix between two batches.

        Args:
            batch1: tensor whose rows are compared against rows of ``batch2``.
            batch2: tensor with the same feature dimension as ``batch1``.
            dist_metric: one of ``"cosine"``, ``"euclidean"`` or
                ``"fast_euclidean"``. Both euclidean variants return the
                *squared* euclidean distance.

        Returns:
            Tensor of shape ``(batch1.size(0), batch2.size(0))``.

        Raises:
            ValueError: if ``dist_metric`` is not a supported name.
        """
        if dist_metric == "cosine":
            batch1 = F.normalize(batch1, p=2, dim=1)
            batch2 = F.normalize(batch2, p=2, dim=1)
            dist_mat = 1 - torch.mm(batch1, batch2.t())
        elif dist_metric == "euclidean":
            m, n = batch1.size(0), batch2.size(0)
            # ||a||^2 + ||b||^2, broadcast to an (m, n) matrix.
            dist_mat = (
                torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n)
                + torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n,
                                                                      m).t()
            )
            # ... - 2 * a.b  => squared euclidean distance.
            # beta/alpha are passed by keyword: the positional
            # (beta, alpha, mat1, mat2) overload is deprecated.
            dist_mat.addmm_(batch1, batch2.t(), beta=1, alpha=-2)
        elif dist_metric == "fast_euclidean":
            batch1 = batch1.unsqueeze(-2)
            batch2 = batch2.unsqueeze(-3)
            dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
        else:
            raise ValueError(
                "Unknown cost function: {}. Expected to "
                "be one of [cosine | euclidean | fast_euclidean]".format(
                    dist_metric
                )
            )
        return dist_mat


class SinkhornDivergence(OptimalTransport):
    """Sinkhorn divergence ``2 * W(x, y) - W(x, x) - W(y, y)``.

    Args:
        dist_metric: cost function name forwarded to ``distance``.
        eps: entropic regularization strength used by the Sinkhorn updates.
        max_iter: maximum number of Sinkhorn iterations.
        bp_to_sinkhorn: if False, the transport plan is detached so that
            gradients flow only through the cost matrix.
    """

    # Early-stopping threshold on the summed absolute change of ``u``.
    thre = 1e-3

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__()
        self.dist_metric = dist_metric
        self.eps = eps
        self.max_iter = max_iter
        self.bp_to_sinkhorn = bp_to_sinkhorn

    def forward(self, x, y):
        """Return the Sinkhorn divergence between two batches.

        x, y: two batches of data with shape (batch, dim).
        """
        W_xy = self.transport_cost(x, y)
        W_xx = self.transport_cost(x, x)
        W_yy = self.transport_cost(y, y)
        return 2*W_xy - W_xx - W_yy

    def transport_cost(self, x, y, return_pi=False):
        """Return ``sum(pi * C)`` for the Sinkhorn plan between x and y.

        Args:
            x, y: batches passed to ``distance`` to build the cost matrix.
            return_pi: if True, also return the transport plan.

        Returns:
            The scalar cost, or ``(cost, pi)`` when ``return_pi`` is True.
        """
        C = self.distance(x, y, dist_metric=self.dist_metric)
        pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre)
        if not self.bp_to_sinkhorn:
            pi = pi.detach()
        cost = torch.sum(pi * C)
        if return_pi:
            return cost, pi
        return cost

    @staticmethod
    def sinkhorn_iterate(C, eps, max_iter, thre):
        """Run log-domain Sinkhorn iterations on a 2-D cost matrix.

        Both marginals are uniform (``1/nx`` and ``1/ny``).

        Args:
            C: cost matrix of shape ``(nx, ny)``.
            eps: entropic regularization strength.
            max_iter: maximum number of iterations.
            thre: stop once the summed absolute change of ``u`` between
                two consecutive iterations drops below this value.

        Returns:
            Transport plan of shape ``(nx, ny)``.
        """
        nx, ny = C.shape
        mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)
        nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)
        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        def M(_C, _u, _v):
            """Modified cost for logarithmic updates.

            Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
            """
            return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps

        # Sinkhorn iterations
        for _ in range(max_iter):
            u0 = u

            u = eps * (
                torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)
            ) + u
            v = (
                eps * (
                    torch.log(nu + 1e-8) -
                    torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
                ) + v
            )

            err = (u - u0).abs().sum()
            if err.item() < thre:
                break

        # Transport plan pi = diag(a)*K*diag(b)
        return torch.exp(M(C, u, v))


class MinibatchEnergyDistance(SinkhornDivergence):
    """Energy distance computed from transport costs between batch halves."""

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__(
            dist_metric=dist_metric,
            eps=eps,
            max_iter=max_iter,
            bp_to_sinkhorn=bp_to_sinkhorn,
        )

    @staticmethod
    def _split_in_two(batch):
        """Split ``batch`` along dim 0 at its midpoint into two views."""
        size = batch.size(0)
        if size < 2:
            raise ValueError(
                "Need at least 2 samples to split a batch in two, "
                "got {}".format(size)
            )
        half = size // 2
        # Slicing (rather than torch.split with a chunk size) always yields
        # exactly two parts, including for odd batch sizes.
        return batch[:half], batch[half:]

    def forward(self, x, y):
        """Return the minibatch energy distance between ``x`` and ``y``.

        Each batch is split into two halves along dim 0; the result sums the
        four cross-batch transport costs and subtracts twice each
        within-batch cost.

        Raises:
            ValueError: if either batch has fewer than 2 samples.
        """
        x1, x2 = self._split_in_two(x)
        y1, y2 = self._split_in_two(y)
        cost = 0
        cost += self.transport_cost(x1, y1)
        cost += self.transport_cost(x1, y2)
        cost += self.transport_cost(x2, y1)
        cost += self.transport_cost(x2, y2)
        cost -= 2 * self.transport_cost(x1, x2)
        cost -= 2 * self.transport_cost(y1, y2)
        return cost


if __name__ == "__main__":
    # example: https://dfdazac.github.io/sinkhorn.html
    n_points = 5
    x = torch.tensor([[i, 0] for i in range(n_points)], dtype=torch.float)
    y = torch.tensor([[i, 1] for i in range(n_points)], dtype=torch.float)
    sinkhorn = SinkhornDivergence(
        dist_metric="euclidean", eps=0.01, max_iter=5
    )
    dist, pi = sinkhorn.transport_cost(x, y, True)
    print("transport cost:", dist.item())
    print("transport plan:")
    print(pi)