in tensorflow_similarity/visualization/confusion_matrix_viz.py [0:0]
def confusion_matrix(y_pred: IntTensor,
y_true: IntTensor,
normalize: bool = True,
labels: IntTensor = None,
title: str = 'Confusion matrix',
cmap: str = 'Blues',
show: bool = True) -> Tuple[Any, FloatTensor]:
"""Plot confusion matrix
Args:
y_pred: Model prediction returned by `model.match()`
y_true: Expected class_id.
normalize: Normalizes matrix values between 0 and 1.
Defaults to True.
labels: List of class string label to display instead of the class
numerical ids. Defaults to None.
title: Title of the confusion matrix. Defaults to 'Confusion matrix'.
cmap: Color schema as CMAP. Defaults to 'Blues'.
show: If the plot is going to be shown or not. Defaults to True.
Returns:
A Tuple containing the plot and confusion matrix.
"""
with tf.device("/cpu:0"):
# Ensure we are working with integer tensors.
if not tf.is_tensor(y_pred):
y_pred = tf.convert_to_tensor(np.array(y_pred))
y_pred = tf.cast(y_pred, dtype='int32')
if not tf.is_tensor(y_true):
y_true = tf.convert_to_tensor(np.array(y_true))
y_true = tf.cast(y_true, dtype='int32')
cm = tf.math.confusion_matrix(y_true, y_pred)
cm = tf.cast(cm, dtype='float')
accuracy = tf.linalg.trace(cm) / tf.math.reduce_sum(cm)
misclass = 1 - accuracy
if normalize:
cm = tf.math.divide_no_nan(
cm,
tf.math.reduce_sum(cm, axis=1)[:, np.newaxis]
)
f, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
ax.set_title(title)
f.colorbar(im)
if labels is not None:
tick_marks = np.arange(len(labels))
ax.set_xticks(tick_marks)
ax.set_xticklabels(labels, rotation=45)
ax.set_yticks(tick_marks)
ax.set_yticklabels(labels)
cm_max = tf.math.reduce_max(cm)
thresh = cm_max / 1.5 if normalize else cm_max / 2.0
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
val = cm[i, j]
color = "white" if val > thresh else "black"
txt = "%.2f" % val if val > 0.0 else "0"
ax.text(j, i, txt, horizontalalignment="center", color=color)
f.tight_layout()
ax.set_ylabel('True label')
ax.set_xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(
accuracy, misclass))
if show:
plt.show()
return ax, cm