[PyTorch] Performance Evaluation of a Classification Model: Confusion Matrix
There are several ways to evaluate the performance of a classification model. One of them is the 'Confusion Matrix', which sorts the model's predictions into groups according to the predicted class and the actual class. From the confusion matrix we can calculate the model's accuracy, sensitivity, specificity, positive predictive value (PPV), negative predictive value (NPV) and F1 score, which are useful indicators of a classifier's performance.
This is an example 2×2 confusion matrix of a binary classifier. (If the model has n classes, the confusion matrix has shape n×n.)
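To make the n×n case concrete, here is a minimal NumPy sketch that builds the matrix from integer class labels. The helper name `confusion_matrix_nxn` and the sample labels are hypothetical; the layout (rows = actual class, columns = predicted class) is assumed to match scikit-learn's `confusion_matrix`, which the code later in this post uses.

```python
import numpy as np


def confusion_matrix_nxn(y_true, y_pred, n_classes):
    """Hypothetical helper: n x n confusion matrix, rows = actual, columns = predicted."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # np.add.at does unbuffered in-place addition, so every repeated
    # (actual, predicted) pair is counted, not just the last one.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


# Made-up labels for a 3-class problem
y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 1]
print(confusion_matrix_nxn(y_true, y_pred, 3))
```

The diagonal holds the correct predictions, and every off-diagonal cell counts one specific kind of mistake (actual class i predicted as class j).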
Let's define some basic terminology.
- True Positive (TP): the model predicted 'Positive' and the actual class is 'Positive', so the prediction is 'True'.
- False Positive (FP): the model predicted 'Positive' but the actual class is 'Negative', so the prediction is 'False'.
- False Negative (FN): the model predicted 'Negative' but the actual class is 'Positive', so the prediction is 'False'.
- True Negative (TN): the model predicted 'Negative' and the actual class is 'Negative', so the prediction is 'True'.
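The four definitions above can be checked by counting them directly. This is a plain-Python sketch, assuming binary labels where `1` means 'Positive'; the function name `count_outcomes` and the sample labels are hypothetical.

```python
def count_outcomes(y_true, y_pred, positive=1):
    """Hypothetical helper: count TP, FP, FN, TN for a binary problem."""
    tp = fp = fn = tn = 0
    for actual, predicted in zip(y_true, y_pred):
        if predicted == positive and actual == positive:
            tp += 1  # predicted Positive, actually Positive
        elif predicted == positive:
            fp += 1  # predicted Positive, actually Negative
        elif actual == positive:
            fn += 1  # predicted Negative, actually Positive
        else:
            tn += 1  # predicted Negative, actually Negative
    return tp, fp, fn, tn


y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 0]
print(count_outcomes(y_true, y_pred))
```

Every sample falls into exactly one of the four groups, so TP + FP + FN + TN always equals the number of samples.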
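These are the performance criteria calculated from the confusion matrix, where P = TP + FN is the number of actual positives and N = TN + FP is the number of actual negatives.
- Accuracy: (TP + TN) / (P + N)
- Sensitivity (recall, true positive rate): TP / P
- Specificity (true negative rate): TN / N
- PPV (precision): TP / (TP + FP)
- NPV: TN / (TN + FN)
- F1 score: 2 * (PPV * Sensitivity) / (PPV + Sensitivity) = (2 * TP) / (2 * TP + FP + FN)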
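As a worked example of the formulas above, the sketch below turns the four counts into each criterion. The function name `binary_metrics` and the counts (TP=40, FP=10, FN=20, TN=30) are made up for illustration, and it assumes no denominator is zero (it does not guard against that).

```python
def binary_metrics(tp, fp, fn, tn):
    """Hypothetical helper: criteria from confusion-matrix counts (no zero-division guard)."""
    p = tp + fn  # number of actual positives
    n = tn + fp  # number of actual negatives
    sensitivity = tp / p
    ppv = tp / (tp + fp)
    return {
        "accuracy": (tp + tn) / (p + n),
        "sensitivity": sensitivity,
        "specificity": tn / n,
        "ppv": ppv,
        "npv": tn / (tn + fn),
        "f1": 2 * ppv * sensitivity / (ppv + sensitivity),
    }


m = binary_metrics(tp=40, fp=10, fn=20, tn=30)
for name, value in m.items():
    print(f"{name}: {value:.4f}")

# The count-based form of F1 gives the same value
print(2 * 40 / (2 * 40 + 10 + 20))
```

With these counts, accuracy is 70/100 = 0.7, sensitivity is 40/60, specificity is 30/40 = 0.75, PPV is 40/50 = 0.8, NPV is 30/50 = 0.6, and both forms of F1 give 80/110.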
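Here is PyTorch code that calculates the confusion matrix of a binary classifier on the test set, together with its accuracy, sensitivity, specificity, PPV, NPV and F1 score. It assumes the test dataloader yields (images, labels, file_name) tuples and that class 1 is 'Positive'.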
```python
import numpy as np
import torch
from sklearn.metrics import confusion_matrix


def test_model(model, dataloaders, device):
    # 2x2 confusion matrix accumulated over all test batches.
    # sklearn layout: CM[actual][predicted], class 1 = 'Positive'
    CM = np.zeros((2, 2), dtype=np.int64)
    model.eval()
    with torch.no_grad():
        for data in dataloaders['test']:
            images, labels, file_name = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, 1)
            # labels=[0, 1] keeps every batch matrix 2x2, even if a batch has one class
            CM += confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy(), labels=[0, 1])

    tn = CM[0][0]
    tp = CM[1][1]
    fp = CM[0][1]
    fn = CM[1][0]

    acc = np.trace(CM) / np.sum(CM)
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    precision = tp / (tp + fp)  # PPV
    npv = tn / (tn + fn)
    f1 = (2 * sensitivity * precision) / (sensitivity + precision)

    print('\nTestset Accuracy(mean): %f %%' % (100 * acc))
    print()
    print('Confusion Matrix : ')
    print(CM)
    print('- Sensitivity : ', sensitivity * 100)
    print('- Specificity : ', specificity * 100)
    print('- Precision: ', precision * 100)
    print('- NPV: ', npv * 100)
    print('- F1 : ', f1 * 100)
    print()
    return acc, CM
```
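The function above adds one confusion matrix per batch into `CM`. Because a confusion matrix only counts (actual, predicted) pairs, the sum over batches equals the matrix computed over the whole test set at once. The sketch below checks that with NumPy alone; it swaps scikit-learn's `confusion_matrix` for a `np.bincount` trick (a different technique, named here plainly), and the helper `batch_confusion` and the two batches are hypothetical stand-ins for the model's predictions on `dataloaders['test']`.

```python
import numpy as np


def batch_confusion(labels, preds, n_classes=2):
    """Hypothetical helper: confusion matrix via np.bincount (rows = actual, cols = predicted)."""
    # Encode each (actual, predicted) pair as a single index: actual * n_classes + predicted
    idx = np.asarray(labels) * n_classes + np.asarray(preds)
    # minlength keeps the result n x n even if some classes are missing in this batch
    counts = np.bincount(idx, minlength=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes)


# Made-up batches of (labels, predictions)
batches = [
    ([0, 1, 1, 0], [0, 1, 0, 0]),
    ([1, 0, 1], [1, 1, 1]),
]

CM = np.zeros((2, 2), dtype=np.int64)
for labels, preds in batches:
    CM += batch_confusion(labels, preds)

all_labels = [y for labels, _ in batches for y in labels]
all_preds = [p for _, preds in batches for p in preds]
print(CM)
print(np.array_equal(CM, batch_confusion(all_labels, all_preds)))
```

`minlength` plays the same role here as `labels=[0, 1]` in `test_model`: without a fixed label set, a batch containing only one class could yield a smaller matrix that no longer lines up with the running total.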