def knn_errorrate()

in tensorflow_hub/tools/module_search/utils.py [0:0]


def knn_errorrate(d, y_train, y_test, k=1):
  """Calculate the knn error rate based on the distance matrix d.

  Row i of d holds the distances from test sample i to its candidate
  neighbors. The predicted label of a test sample is the majority vote among
  the labels of its k closest candidates; a tie in the vote is resolved in
  favor of the smallest label.

  Args:
    d: distance matrix (array-like), one row per test sample
    y_train: label vector for the training samples (indexed by the columns of
      d) / or matrix holding the label per entry in d
    y_test: label vector for the test samples
    k: number of direct neighbors for knn or list/tuple of multiple k's to
      evaluate

  Returns:
    knn error rate (1 - accuracy) as a float if a single k is given,
    otherwise a list with one error rate per distinct k, ordered by
    descending k (duplicate k's are evaluated only once).

  Raises:
    ValueError: if any requested k is smaller than 1.
  """
  return_list = isinstance(k, (list, tuple))
  ks = sorted(set(k), reverse=True) if return_list else [k]

  # Validate once up front: a k below 1 would otherwise either fail with an
  # unrelated numpy error or silently vote over the wrong neighbors.
  if ks[-1] < 1:
    raise ValueError("No smaller value than '1' allowed for k")

  d = np.asarray(d)
  labels = np.asarray(y_train)
  num_elements = np.shape(d)[0]
  rows = np.arange(num_elements)

  res = []
  for pos, val_k in enumerate(ks):
    # A label vector is shared by all rows; anything else is indexed per row.
    shared_labels = labels.ndim == 1

    if val_k == 1:
      nearest = np.argmin(d, axis=1)
      if shared_labels:
        predicted = labels[nearest]
      else:
        predicted = labels[rows, nearest]
    else:
      # The first val_k columns hold the val_k smallest distances (unordered).
      nearest = np.argpartition(d, val_k - 1, axis=1)[:, :val_k]
      # Fancy indexing yields a fresh array that keeps the dtype of the
      # labels, so float or string labels are not coerced to integers.
      if shared_labels:
        neighbor_labels = labels[nearest]
      else:
        neighbor_labels = labels[rows[:, None], nearest]
      predicted = [_majority_label(row) for row in neighbor_labels]

    cnt = sum(1 for i in range(num_elements) if y_test[i] != predicted[i])
    res.append(float(cnt) / num_elements)

    if val_k > 1 and pos + 1 < len(ks):
      # Every smaller k only needs the val_k candidates found here, so narrow
      # the distances and the (now per-entry) labels to those candidates.
      d = d[rows[:, None], nearest]
      labels = neighbor_labels

  if not return_list:
    return res[0]
  return res


def _majority_label(labels):
  """Return the most frequent value in labels (smallest value on a tie)."""
  # np.unique returns the keys sorted and np.argmax the first maximum, hence
  # the smallest key wins among equally frequent ones.
  keys, counts = np.unique(labels, return_counts=True)
  return keys[np.argmax(counts)]