Confusion Matrix 101: Understanding Precision and Recall for Machine Learning Beginners

When we build a machine learning model, choosing an appropriate evaluation metric is crucial, because that metric decides whether the model is actually fit for its practical use-case. For regression problems we have error measures such as mean squared error, but for classification problems the obvious choice, accuracy, can be badly misleading, especially when the classes are imbalanced. The performance of a classifier is better analyzed using a confusion matrix and the metrics derived from it: precision and recall.

In this blog, we will understand the confusion matrix, precision, and recall, and how to use them to analyze model performance.

The Problem with Accuracy for Classifiers

Consider the MNIST example, where we are given a dataset of hand-written digits. We use this data to build a classifier that recognizes only the digit “1” and flags it as true; everything else is false, so this is a binary classification problem.

from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
import matplotlib.pyplot as plt
import matplotlib as mpl

mnist = fetch_openml("mnist_784", version=1, as_frame=True)  # load MNIST as pandas objects
X, y = mnist["data"], mnist["target"]
y = (y == "1")  # binary target: True if the digit is "1", False otherwise
y.shape

# Plotting an example from a dataset
plt.imshow(X.iloc[0].values.reshape(28,28))

[Figure: the first sample of the dataset rendered as a 28×28 grayscale image]

Each sample is a flat array of 784 pixel values, so we reshape it into a 28×28 matrix before handing it to matplotlib’s imshow function.

Now, we split the data into training and test sets. The split strategy for this example will be the good ol’ stratified split, since it maintains the class proportions in both sets. Take a look at How to sample data *CORRECTLY* into test set? to understand why a stratified split is the preferred way to split data.

from sklearn.model_selection import StratifiedShuffleSplit

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
# Stratify on the original 10-class labels so every digit keeps its proportion
for train_index, test_index in split.split(X, mnist["target"]):
    # The indices returned by the splitter are positional, so we use iloc
    strat_train_set = X.iloc[train_index]
    strat_train_labels = y.iloc[train_index]
    strat_test_set = X.iloc[test_index]
    strat_test_labels = y.iloc[test_index]

One general rule while building a machine learning model is to keep the test set completely aside and only use it for final testing when your model is ready to be deployed. While we are still working on the model, it is better to use cross-validation to analyze its performance. During cross-validation, the training set is divided into a number of “folds”; the model is trained on all but one fold, and the remaining fold is used to evaluate it. Let us say that we choose the stochastic gradient descent (SGD) classifier for this binary classification problem.

Scikit-learn provides a class called StratifiedKFold to generate a user-defined number of folds and carry out cross-validation for us. The code below demonstrates this; at each iteration we monitor the model’s performance on the held-out fold. The n_splits argument defines the number of folds we want to create. Hence, with n_splits=3, we get 3 iterations, and in each iteration 2 folds are used for training and 1 for evaluation. We print the accuracy at the end of each iteration. [Note that the evaluation fold here is a small subset of the training set and not the “original” test set that we kept aside during the initial train-test split.]

from sklearn.base import clone
from sklearn.model_selection import StratifiedKFold

classifier_model = SGDClassifier(random_state=42)
skfolds = StratifiedKFold(n_splits=3)

for train_index, test_index in skfolds.split(strat_train_set, strat_train_labels):
    # Split the training data into this iteration's train and evaluation folds
    train_set = strat_train_set.iloc[train_index]
    train_labels = strat_train_labels.iloc[train_index]
    test_set = strat_train_set.iloc[test_index]
    test_labels = strat_train_labels.iloc[test_index]

    # Train a fresh copy of the model and evaluate it on the held-out fold
    model_clone = clone(classifier_model)
    model_clone.fit(train_set, train_labels)
    predictions = model_clone.predict(test_set)
    accuracy = sum(predictions == test_labels) / len(test_labels)
    print(accuracy)
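If you do not need the manual loop, scikit-learn’s cross_val_score wraps this whole procedure in a single call. Here is a minimal equivalent sketch (for classifiers, an integer cv uses stratified folds by default):

from sklearn.model_selection import cross_val_score

# Stratified 3-fold cross-validation, scored by accuracy, in one call
scores = cross_val_score(classifier_model, strat_train_set, strat_train_labels,
                         cv=3, scoring="accuracy")
print(scores)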

The accuracy we achieve is approximately 0.98 (98%). This sounds great, doesn’t it? Our model is now ready to classify if a hand-written digit is ‘1’ or not! Well, not really. Let us first understand the proportion of images that are labeled as ‘1’.

print(sum(y==True)/len(y))
[Output]:
0.11252857142857142

Uh oh! Only about 11.25% of the dataset is labeled as ‘1’ (i.e. ‘true’ in our setup); the rest is ‘false’. That means a 98% accuracy tells us very little about how well the model recognizes the digit “1”: a model that blindly predicts “not 1” for every single image would already be right about 89% of the time. This is where the terms precision and recall come in.
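Before moving on, you can verify that ~89% baseline with scikit-learn’s DummyClassifier, which simply predicts the majority class every time (a quick sanity check, not part of the pipeline we are building):

from sklearn.dummy import DummyClassifier
from sklearn.model_selection import cross_val_score

# A "model" that always predicts the majority class, i.e. "not 1"
never_1 = DummyClassifier(strategy="most_frequent")
baseline_scores = cross_val_score(never_1, strat_train_set, strat_train_labels,
                                  cv=3, scoring="accuracy")
print(baseline_scores)  # roughly 0.89 on every fold, without ever predicting a "1"

Beating this naive baseline is the bar our 98% number should really be measured against.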

Confusion Matrix

In a classification problem like this, it is better to understand whether the model can predict a specific label or not. In other words, we monitor the following four counts to understand model performance:

  • true positive (TP): the prediction is true and so is the ground truth label,
  • false positive (FP): the prediction is true but the ground truth label is false,
  • false negative (FN): the prediction is false but the ground truth label is true,
  • true negative (TN): the prediction is false and so is the ground truth label.

The table built from these counts is called the confusion matrix:

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix

# cross_val_predict returns out-of-fold predictions for every training instance
model_predictions = cross_val_predict(classifier_model, strat_train_set, strat_train_labels, cv=3)

confusion_matrix(strat_train_labels, model_predictions)
[Output]:
array([[49202,   496],
       [  363,  5939]])

Now, each row in a confusion matrix represents the ground truth, while each column represents the predicted class. Hence, in the confusion matrix above, we conclude the following:

  • There are 49202 true negatives: instances where the actual digit is not “1” and the model correctly predicted it as such.
  • There are 496 false positives: the model predicted these images as “1” whereas, in the actual dataset, they are not labeled as “1”.
  • There are 363 false negatives: the model predicted “not 1” even though the ground truth label is “1”, and
  • There are 5939 true positives: the model correctly predicted these samples as “1”.
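If you would rather unpack these four counts programmatically than read them off the matrix, ravel() flattens the 2×2 array in exactly this order:

# The 2x2 matrix is laid out as [[tn, fp], [fn, tp]]
tn, fp, fn, tp = confusion_matrix(strat_train_labels, model_predictions).ravel()
print(tn, fp, fn, tp)  # 49202 496 363 5939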

Precision and Recall

The confusion matrix tells us how closely the model’s predictions match the ground truth. A perfect model would have 0 false positives and 0 false negatives, but we don’t live in an ideal world, do we? To condense this into a single, more interpretable metric, we first look at the precision of the model.

Precision is defined as the ratio of true positives to the sum of true positives and false positives.

\(\text{Precision} = \frac{\text{TP}}{\text{TP + FP}}\)

If a model has 100% precision, none of its positive predictions are wrong: no sample is falsely predicted as true. But what if the model falsely predicts positive samples as false? This is where we introduce an additional metric called recall.

Recall is the ratio of true positives to the sum of true positives and false negatives.

\(\text{Recall} = \frac{\text{TP}}{\text{TP + FN}}\)

Recall tells us how many of the actual positives the model manages to catch; a low recall means the model produces many false negatives.

from sklearn.metrics import precision_score, recall_score

print(precision_score(strat_train_labels, model_predictions))
print(recall_score(strat_train_labels, model_predictions))
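As a sanity check, we can plug the numbers from the confusion matrix above directly into the two formulas; the scores printed by scikit-learn should match these hand-computed values:

# Precision and recall computed by hand from the confusion matrix above
tp, fp, fn = 5939, 496, 363
print(tp / (tp + fp))  # precision ≈ 0.923
print(tp / (tp + fn))  # recall    ≈ 0.942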

Now, pause and take a moment to understand the difference between precision and recall. Try to think of a possible use-case for both of them before proceeding.

Precision or Recall: When to use what?

A potential use-case for precision would be a model that separates rotten fruits from good fruits, where a “positive” prediction means the fruit is good and is allowed to pass into the machine. A company manufacturing a fruit-based product would prefer a model with a low rate of false positives, so that bad fruits do not slip into the machine. False negatives, on the other hand, are not a big issue: good fruits that get rejected can be added back to the machine after a manual check.

Conversely, when you are working on an automated-driving vehicle and want to develop a function that stops the vehicle at red lights, false negatives are the critical concern. You do not want the vehicle to classify an actual red light as “no red light”, hence recall would be the metric of focus in such a case.

Summary

In this blog, we looked at a method for evaluating classification models: the confusion matrix. The matrix provides a comparison between the ground truth and the model’s predictions. Based on the problem statement, we can then calculate the precision or the recall to check whether the model’s performance is suited for the application.
