in tensorflow_hub/tools/module_search/utils.py [0:0]
def knn_errorrate(d, y_train, y_test, k=1):
  """Calculate the knn error rate based on the distance matrix d.

  Args:
    d: distance matrix; row i holds the distances from test sample i to every
      candidate neighbor (one column per candidate).
    y_train: label vector for the training samples (one label per column of
      d), or a matrix with one label per entry in d.
    y_test: label vector for the test samples (one label per row of d).
    k: number of direct neighbors for knn, or a list/tuple of multiple k's to
      evaluate. Duplicate values are evaluated only once.

  Returns:
    If k is a scalar, the knn error rate (1 - accuracy) as a float. If k is a
    list/tuple, a list with one error rate per distinct k, ordered by k in
    descending order.

  Raises:
    ValueError: if k is an empty list/tuple or any k is smaller than 1.
  """
  return_list = isinstance(k, (list, tuple))
  ks = sorted(set(k), reverse=True) if return_list else [k]
  if not ks:
    raise ValueError("k must contain at least one value")
  # Validate before doing any work: a k < 1 would otherwise either fail late
  # with an unrelated error or (for negative k) silently vote over the wrong
  # set of columns.
  if ks[-1] < 1:
    raise ValueError("No smaller value than '1' allowed for k")

  d = np.asarray(d)
  y_train = np.asarray(y_train)
  num_elements = d.shape[0]
  # Column vector of row indices; broadcasts against a (num_elements, val_k)
  # array of column indices to pick one entry per (row, neighbor) pair.
  rows = np.arange(num_elements)[:, np.newaxis]

  results = []
  for pos, val_k in enumerate(ks):
    # Column indices of the val_k nearest candidates of each row. Their order
    # within a row is unspecified (argpartition does not sort the partition).
    if val_k == 1:
      nearest = np.argmin(d, axis=1)[:, np.newaxis]
    else:
      nearest = np.argpartition(d, val_k - 1, axis=1)[:, :val_k]

    if y_train.ndim == 1:
      labels = y_train[nearest]
    else:
      labels = y_train[rows, nearest]

    if val_k == 1:
      predictions = labels[:, 0]
    else:
      predictions = [_majority_label(row) for row in labels]
    num_errors = sum(
        1 for i, pred in enumerate(predictions) if y_test[i] != pred)
    results.append(float(num_errors) / num_elements)

    if pos + 1 < len(ks):
      # Every smaller k only needs the current val_k nearest neighbors, so
      # shrink d and the per-entry labels to those columns. The labels are
      # gathered by fancy indexing, which preserves y_train's dtype (float or
      # string labels are not coerced into the integer index array).
      d = d[rows, nearest]
      y_train = labels

  return results if return_list else results[0]


def _majority_label(labels):
  """Return the most frequent value in the 1-D array `labels`.

  Ties are broken in favour of the smallest label: np.unique returns the
  distinct values sorted, and np.argmax returns the first maximal count.
  """
  keys, counts = np.unique(labels, return_counts=True)
  return keys[np.argmax(counts)]