def compute_distance_matrix()

in tensorflow_hub/tools/module_search/utils.py [0:0]


def compute_distance_matrix(x_train, x_test, measure="squared_l2"):
  """Calculates the distance matrix between test and train.

  The inputs are converted to float32 before any computation. When
  `tf.test.is_gpu_available()` reports a GPU, the matrix product and the
  per-row reductions run through TensorFlow; otherwise they run in NumPy.
  The result is always a NumPy array.

  Args:
    x_train: Matrix (NxD) where each row represents a training sample
    x_test: Matrix (MxD) where each row represents a test sample
    measure: Distance measure (not necessarily metric) to use. Supported
      values are "squared_l2" and "cosine".

  Raises:
    NotImplementedError: When the measure is not implemented

  Returns:
    Matrix (MxN) where element i,j is the distance between
    x_test_i and x_train_j. For "cosine", a row with zero norm leads to a
    division by zero, so the corresponding entries are not finite.
  """
  # Reject unknown measures up front so no conversion or device probing is
  # done for a call that cannot succeed.
  if measure not in ("squared_l2", "cosine"):
    raise NotImplementedError("Method '{}' is not implemented".format(measure))

  # Probe for a GPU once; the answer does not change within a call.
  use_gpu = tf.test.is_gpu_available()

  # Both measures start from the (MxN) matrix of inner products <test_i, train_j>.
  if use_gpu:
    x_train = tf.convert_to_tensor(x_train, tf.float32)
    x_test = tf.convert_to_tensor(x_test, tf.float32)
    x_xt = tf.matmul(x_test, tf.transpose(x_train)).numpy()
  else:
    # np.asarray also accepts lists/tuples and does not copy an array that is
    # already float32.
    x_train = np.asarray(x_train, dtype=np.float32)
    x_test = np.asarray(x_test, dtype=np.float32)
    x_xt = np.matmul(x_test, np.transpose(x_train))

  if measure == "squared_l2":
    if use_gpu:
      x_train_2 = tf.reduce_sum(tf.math.square(x_train), 1).numpy()
      x_test_2 = tf.reduce_sum(tf.math.square(x_test), 1).numpy()
    else:
      x_train_2 = np.sum(np.square(x_train), axis=1)
      x_test_2 = np.sum(np.square(x_test), axis=1)

    # ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2, applied to the whole matrix
    # by broadcasting: test norms down the rows, train norms across the
    # columns. x_xt is a freshly computed array, so updating it in place does
    # not touch the caller's data.
    x_xt *= -2
    x_xt += x_test_2[:, np.newaxis]
    x_xt += x_train_2[np.newaxis, :]
    return x_xt

  # measure == "cosine": 1 - <a, b> / (||a|| * ||b||)
  if use_gpu:
    x_train_2 = tf.linalg.norm(x_train, axis=1).numpy()
    x_test_2 = tf.linalg.norm(x_test, axis=1).numpy()
  else:
    x_train_2 = np.linalg.norm(x_train, axis=1)
    x_test_2 = np.linalg.norm(x_test, axis=1)

  outer = np.outer(x_test_2, x_train_2)
  # np.ones defaults to float64, so the subtraction yields a float64 matrix.
  return np.ones(np.shape(x_xt)) - np.divide(x_xt, outer)