def compute_dist()

in utils/loss_fn.py [0:0]


def compute_dist(array1, array2, type='euclidean'):
    """Compute the pairwise cosine similarity or euclidean distance of two row sets.

    Args:
        array1: array with shape [m1, n]. A numpy array (or array-like) or a
            torch tensor.
        array2: array with shape [m2, n]. A numpy array (or array-like) or a
            torch tensor.
        type: one of ['cosine', 'euclidean'].
            - 'cosine': rows are passed through ``normalize(..., axis=1)`` and
              the result is ``array1 @ array2.T``. Assuming ``normalize``
              L2-normalizes rows (as ``sklearn.preprocessing.normalize`` does
              by default), this is cosine *similarity*: larger means closer.
              It is not a distance.
            - 'euclidean': the euclidean distance between every pair of rows.
    Returns:
        Array with shape [m1, m2]. For 'euclidean', this is a torch tensor if
        either input is a tensor, and a numpy array otherwise.
    Raises:
        ValueError: if ``type`` is not one of the supported values.
    """
    # Use an explicit check, not assert: asserts are stripped under `python -O`,
    # which would silently route any bad `type` into the euclidean branch.
    if type not in ('cosine', 'euclidean'):
        raise ValueError(
            "type must be one of ['cosine', 'euclidean'], got {!r}".format(type))

    if type == 'cosine':
        array1 = normalize(array1, axis=1)
        array2 = normalize(array2, axis=1)
        return np.matmul(array1, array2.T)

    # Euclidean distance via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b,
    # which avoids materializing an [m1, m2, n] difference tensor.
    if torch.is_tensor(array1) or torch.is_tensor(array2):
        # Torch path (keeps autograd). Only a non-tensor operand is converted,
        # and it is matched to the other operand's dtype/device so matmul works.
        if not torch.is_tensor(array1):
            array1 = torch.as_tensor(array1, dtype=array2.dtype, device=array2.device)
        if not torch.is_tensor(array2):
            array2 = torch.as_tensor(array2, dtype=array1.dtype, device=array1.device)
        # shape [m1, 1]
        square1 = torch.unsqueeze(torch.sum(torch.square(array1), dim=1), 1)
        # shape [1, m2]
        square2 = torch.unsqueeze(torch.sum(torch.square(array2), dim=1), 0)
        squared_dist = -2 * torch.matmul(array1, array2.T) + square1 + square2
        # Rounding can make near-zero squared distances slightly negative;
        # clamp them so sqrt does not produce NaN.
        # NOTE: sqrt has an infinite derivative at 0, so backpropagating through
        # exactly-zero distances (e.g. identical rows) can yield NaN gradients.
        return torch.sqrt(torch.clamp(squared_dist, min=0))

    # NumPy path.
    array1 = np.asarray(array1)
    array2 = np.asarray(array2)
    # shape [m1, 1]
    square1 = np.sum(np.square(array1), axis=1)[:, np.newaxis]
    # shape [1, m2]
    square2 = np.sum(np.square(array2), axis=1)[np.newaxis, :]
    squared_dist = -2 * np.matmul(array1, array2.T) + square1 + square2
    # Same rounding guard as the torch path.
    return np.sqrt(np.maximum(squared_dist, 0))