import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix


def _show_heatmap(matrix, labels, cmap, title, fmt=".3f"):
    """Print a banner with *title* and display *matrix* as an annotated heatmap."""
    print("-" * 50, title, "-" * 50)
    plt.figure(figsize=(10, 5))
    sns.heatmap(
        matrix,
        annot=True,
        cmap=cmap,
        fmt=fmt,
        xticklabels=labels,
        yticklabels=labels,
    )
    plt.xlabel('Predicted Class')
    plt.ylabel('Original Class')
    plt.show()


def plot_confusion_matrix(actual, pred, threshold=15,
                          palette=sns.light_palette('green')):
    """Plot the confusion, precision and recall matrices of a classifier.

    Three heatmaps are shown in order: the raw confusion matrix (counts),
    the column-normalised matrix (precision: each column sums to 1) and
    the row-normalised matrix (recall: each row sums to 1). The percentage
    of misclassified points and the column sums of the precision matrix
    are printed to stdout.

    Parameters
    ----------
    actual : array-like
        Ground-truth class labels.
    pred : array-like
        Predicted class labels, same length as ``actual``.
    threshold : int, default 15
        Maximum number of distinct classes that will be plotted.
    palette : colour palette / colormap
        Passed to ``seaborn.heatmap`` as ``cmap``.

    Raises
    ------
    AttributeError
        If the number of distinct classes exceeds ``threshold``.
    """
    # Use the union of labels from both inputs and hand it explicitly to
    # confusion_matrix so matrix rows/columns and tick labels always line up,
    # even when ``pred`` contains a class that never occurs in ``actual``.
    labels = np.union1d(actual, pred)
    C = confusion_matrix(actual, pred, labels=labels)

    # NOTE: the printed value is a percentage, not an absolute count.
    print("Number of misclassified points ",
          (len(actual) - np.trace(C)) / len(actual) * 100)

    if len(labels) > threshold:
        raise AttributeError(
            f"Number of classes ({len(labels)}) exceeds threshold "
            f"({threshold}); increase `threshold` to plot."
        )

    tick_labels = labels.tolist()
    cmap = palette

    # Recall: divide each row by the number of true samples of that class.
    # Rows with no true samples are left at 0 instead of producing NaN.
    row_totals = C.sum(axis=1, keepdims=True)
    A = np.divide(C, row_totals,
                  out=np.zeros(C.shape, dtype=float),
                  where=row_totals != 0)

    # Precision: divide each column by the number of predictions of that
    # class. Never-predicted classes are left at 0 instead of producing NaN.
    col_totals = C.sum(axis=0, keepdims=True)
    B = np.divide(C, col_totals,
                  out=np.zeros(C.shape, dtype=float),
                  where=col_totals != 0)

    # Raw counts are integers, so annotate them as such.
    _show_heatmap(C, tick_labels, cmap, "Confusion matrix", fmt="d")

    _show_heatmap(B, tick_labels, cmap, "Precision matrix")
    print("Sum of columns in precision matrix", B.sum(axis=0))

    _show_heatmap(A, tick_labels, cmap, "Recall matrix")