Chapter 7 Classification II: evaluation & tuning

7.1 Overview

This chapter continues the introduction to predictive modelling through classification. While the previous chapter covered training and data preprocessing, this chapter focuses on how to split data, how to evaluate prediction accuracy, and how to choose model parameters to maximize performance.

7.2 Chapter learning objectives

By the end of the chapter, students will be able to:

  • Describe what training, validation, and test data sets are and how they are used in classification
  • Split data into training, validation, and test data sets
  • Evaluate classification accuracy in R using a validation data set and appropriate metrics
  • Execute cross-validation in R to choose the number of neighbours in a K-nearest neighbour classifier
  • Describe advantages and disadvantages of the K-nearest neighbour classification algorithm

7.3 Evaluating accuracy

Sometimes our classifier might make the wrong prediction. A classifier does not need to be right 100% of the time to be useful, though we don’t want the classifier to make too many wrong predictions. How do we measure how “good” our classifier is? Let’s revisit the breast cancer images example and think about how our classifier will be used in practice. A biopsy will be performed on a new patient’s tumour, the resulting image will be analyzed, and the classifier will be asked to decide whether the tumour is benign or malignant. The key word here is new: our classifier is “good” if it provides accurate predictions on data not seen during training. But then how can we evaluate our classifier without having to visit the hospital to collect more tumour images?

The trick is to split up the data set into a training set and test set, and only show the classifier the training set when building the classifier. Then to evaluate the accuracy of the classifier, we can use it to predict the labels (which we know) in the test set. If our predictions match the true labels for the observations in the test set very well, then we have some confidence that our classifier might also do a good job of predicting the class labels for new observations that we do not have the class labels for.

Note: if there were a golden rule of machine learning, it might be this: you cannot use the test data to build the model! If you do, the model gets to “see” the test data in advance, making it look more accurate than it really is. Imagine how bad it would be to overestimate your classifier’s accuracy when predicting whether a patient’s tumour is malignant or benign!

Figure 7.1: Splitting the data into training and testing sets

How exactly can we assess how well our predictions match the true labels for the observations in the test set? One way we can do this is to calculate the prediction accuracy. This is the fraction of examples for which the classifier made the correct prediction. To calculate this we divide the number of correct predictions by the number of predictions made. Other measures for how well our classifier performed include precision and recall; these will not be discussed here, but you will encounter them in other more advanced courses on this topic. This process is illustrated below:

Figure 7.2: Process for splitting the data and finding the prediction accuracy
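
For example, if a classifier made 100 predictions and 95 of them matched the true labels, its accuracy would be 95 / 100 = 0.95, or 95%. As a minimal sketch (using short hypothetical vectors of predicted and true labels, not the real data), the calculation in R is just an element-wise comparison followed by an average:

# hypothetical vectors of predicted and true labels
predicted <- c("M", "B", "B", "M", "B")
true_labels <- c("M", "B", "M", "M", "B")

# accuracy is the fraction of predictions that match the true labels
mean(predicted == true_labels)
## [1] 0.8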

In R, we can use the tidymodels library collection not only to perform K-nearest neighbour classification, but also to assess how well our classification worked. Let’s start by loading the necessary libraries, reading in the breast cancer data from the previous chapter, and making a quick scatter plot visualization of tumour cell concavity versus smoothness coloured by diagnosis.

# load packages
library(tidyverse)
library(tidymodels)

# load data
cancer <- read_csv("data/unscaled_wdbc.csv") %>%
  mutate(Class = as_factor(Class)) # convert the character Class variable to the factor datatype

# colour palette
cbPalette <- c("#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")

# create scatter plot of tumour cell concavity versus smoothness,
# labelling the points by diagnosis class
perim_concav <- cancer %>%
  ggplot(aes(x = Smoothness, y = Concavity, color = Class)) +
  geom_point(alpha = 0.5) +
  labs(color = "Diagnosis") +
  scale_color_manual(labels = c("Malignant", "Benign"), values = cbPalette)

perim_concav

Figure 7.3: Scatterplot of tumour cell concavity versus smoothness coloured by diagnosis label

1. Create the train / test split

Once we have decided on a predictive question to answer and done some preliminary exploration, the very next thing to do is to split the data into the training and test sets. Typically, the training set is between 50% and 95% of the data, while the test set is the remaining 5% to 50%; the intuition is that you want to trade off between training an accurate model (by using a larger training data set) and getting an accurate evaluation of its performance (by using a larger test data set). Here, we will use 75% of the data for training and 25% for testing. To do this we will use the initial_split function, specifying prop = 0.75 and strata = Class so that the split is stratified by the target variable:

set.seed(1)
cancer_split <- initial_split(cancer, prop = 0.75, strata = Class)
cancer_train <- training(cancer_split)
cancer_test <- testing(cancer_split)

Note: You will see in the code above that we use the set.seed function again, as discussed in the previous chapter. In this case it is because initial_split uses random sampling to choose which rows will be in the training set. Since we want our code to be reproducible and generate the same train/test split each time it is run, we use set.seed.

glimpse(cancer_train)
## Rows: 427
## Columns: 12
## $ ID                <dbl> 842302, 842517, 84300903, 84348301, 84358402, 843786, 844359, 84458202…
## $ Class             <fct> M, M, M, M, M, M, M, M, M, M, M, M, M, M, M, B, B, M, M, M, M, M, M, M…
## $ Radius            <dbl> 17.990, 20.570, 19.690, 11.420, 20.290, 12.450, 18.250, 13.710, 12.460…
## $ Texture           <dbl> 10.38, 17.77, 21.25, 20.38, 14.34, 15.70, 19.98, 20.83, 24.04, 23.24, …
## $ Perimeter         <dbl> 122.80, 132.90, 130.00, 77.58, 135.10, 82.57, 119.60, 90.20, 83.97, 10…
## $ Area              <dbl> 1001.0, 1326.0, 1203.0, 386.1, 1297.0, 477.1, 1040.0, 577.9, 475.9, 79…
## $ Smoothness        <dbl> 0.11840, 0.08474, 0.10960, 0.14250, 0.10030, 0.12780, 0.09463, 0.11890…
## $ Compactness       <dbl> 0.27760, 0.07864, 0.15990, 0.28390, 0.13280, 0.17000, 0.10900, 0.16450…
## $ Concavity         <dbl> 0.30010, 0.08690, 0.19740, 0.24140, 0.19800, 0.15780, 0.11270, 0.09366…
## $ Concave_Points    <dbl> 0.14710, 0.07017, 0.12790, 0.10520, 0.10430, 0.08089, 0.07400, 0.05985…
## $ Symmetry          <dbl> 0.2419, 0.1812, 0.2069, 0.2597, 0.1809, 0.2087, 0.1794, 0.2196, 0.2030…
## $ Fractal_Dimension <dbl> 0.07871, 0.05667, 0.05999, 0.09744, 0.05883, 0.07613, 0.05742, 0.07451…
glimpse(cancer_test)
## Rows: 142
## Columns: 12
## $ ID                <dbl> 844981, 84799002, 848406, 849014, 8510426, 8511133, 853401, 854002, 85…
## $ Class             <fct> M, M, M, M, B, M, M, M, M, M, M, B, B, M, M, M, B, M, M, B, B, B, M, M…
## $ Radius            <dbl> 13.000, 14.540, 14.680, 19.810, 13.540, 15.340, 18.630, 19.270, 13.440…
## $ Texture           <dbl> 21.82, 27.54, 20.13, 22.15, 14.36, 14.26, 25.11, 26.47, 21.58, 20.28, …
## $ Perimeter         <dbl> 87.50, 96.73, 94.74, 130.00, 87.46, 102.50, 124.80, 127.90, 86.18, 87.…
## $ Area              <dbl> 519.8, 658.8, 684.5, 1260.0, 566.3, 704.4, 1088.0, 1162.0, 563.0, 545.…
## $ Smoothness        <dbl> 0.12730, 0.11390, 0.09867, 0.09831, 0.09779, 0.10730, 0.10640, 0.09401…
## $ Compactness       <dbl> 0.19320, 0.15950, 0.07200, 0.10270, 0.08129, 0.21350, 0.18870, 0.17190…
## $ Concavity         <dbl> 0.18590, 0.16390, 0.07395, 0.14790, 0.06664, 0.20770, 0.23190, 0.16570…
## $ Concave_Points    <dbl> 0.093530, 0.073640, 0.052590, 0.094980, 0.047810, 0.097560, 0.124400, …
## $ Symmetry          <dbl> 0.2350, 0.2303, 0.1586, 0.1582, 0.1885, 0.2521, 0.2183, 0.1853, 0.1784…
## $ Fractal_Dimension <dbl> 0.07389, 0.07077, 0.05922, 0.05395, 0.05766, 0.07032, 0.06197, 0.06261…

We can see from glimpse in the code above that the training set contains 427 observations, while the test set contains 142 observations. This corresponds to a train / test split of 75% / 25%, as desired.
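
If you want to verify the split proportion directly, a quick check using the data frames created above confirms it:

# fraction of the original observations that ended up in the training set
nrow(cancer_train) / nrow(cancer)
## [1] 0.7504394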

2. Pre-process the data

As we mentioned last chapter, K-NN is sensitive to the scale of the predictors, and so we should perform some preprocessing to standardize them. An additional consideration we need to take when doing this is that we should create the standardization preprocessor using only the training data. This ensures that our test data does not influence any aspect of our model training. Once we have created the standardization preprocessor, we can then apply it separately to both the training and test data sets.

Fortunately, the recipe framework from tidymodels makes it simple to handle this properly. Below we construct and prepare the recipe using only the training data (due to data = cancer_train in the first line).

cancer_recipe <- recipe(Class ~ Smoothness + Concavity, data = cancer_train) %>%
  step_scale(all_predictors()) %>%
  step_center(all_predictors())
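
The workflow we build below will prepare and apply this recipe automatically. But if you want to inspect the standardized data yourself, a minimal sketch using the prep and bake functions from recipes (applying the training-data-based preprocessor to both data sets) might look like this:

# estimate the centering and scaling parameters from the training data only
cancer_prepped <- prep(cancer_recipe)

# apply the same standardization to both the training and test sets
scaled_cancer_train <- bake(cancer_prepped, new_data = cancer_train)
scaled_cancer_test <- bake(cancer_prepped, new_data = cancer_test)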

3. Train the classifier

Now that we have split our original data set into training and test sets, we can create our K-nearest neighbour classifier with only the training set using the technique we learned in the previous chapter. For now, we will just choose the number \(K\) of neighbours to be 3, and use concavity and smoothness as the predictors.

set.seed(1)
knn_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = 3) %>%
  set_engine("kknn") %>%
  set_mode("classification")

knn_fit <- workflow() %>%
  add_recipe(cancer_recipe) %>%
  add_model(knn_spec) %>%
  fit(data = cancer_train)

knn_fit
## ══ Workflow [trained] ═════════════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: nearest_neighbor()
## 
## ── Preprocessor ───────────────────────────────────────────────────────────────────────────────────
## 2 Recipe Steps
## 
## ● step_scale()
## ● step_center()
## 
## ── Model ──────────────────────────────────────────────────────────────────────────────────────────
## 
## Call:
## kknn::train.kknn(formula = formula, data = data, ks = ~3, kernel = ~"rectangular")
## 
## Type of response variable: nominal
## Minimal misclassification: 0.1264637
## Best kernel: rectangular
## Best k: 3

Note: Here again you see the set.seed function. In the K-nearest neighbour algorithm, if there is a tie for the majority neighbour class, the winner is randomly selected. Although there is no chance of a tie when \(K\) is odd (here \(K=3\)), it is possible that the code may be changed in the future to use an even value of \(K\). Thus, to prevent potential issues with reproducibility, we have set the seed. Note that in your own code, you only have to set the seed once at the beginning of your analysis.

4. Predict the labels in the test set

Now that we have a K-nearest neighbour classifier object, we can use it to predict the class labels for our test set. We use the bind_cols function to add the column of predictions to the original test data, creating the cancer_test_predictions data frame. The Class variable contains the true diagnoses, while the .pred_class variable contains the predicted diagnoses from the model.

cancer_test_predictions <- predict(knn_fit, cancer_test) %>%
  bind_cols(cancer_test)
cancer_test_predictions
## # A tibble: 142 x 13
##    .pred_class     ID Class Radius Texture Perimeter  Area Smoothness Compactness Concavity
##    <fct>        <dbl> <fct>  <dbl>   <dbl>     <dbl> <dbl>      <dbl>       <dbl>     <dbl>
##  1 M           8.45e5 M       13      21.8      87.5  520.     0.127       0.193     0.186 
##  2 M           8.48e7 M       14.5    27.5      96.7  659.     0.114       0.160     0.164 
##  3 B           8.48e5 M       14.7    20.1      94.7  684.     0.0987      0.072     0.0740
##  4 M           8.49e5 M       19.8    22.2     130   1260      0.0983      0.103     0.148 
##  5 B           8.51e6 B       13.5    14.4      87.5  566.     0.0978      0.0813    0.0666
##  6 M           8.51e6 M       15.3    14.3     102.   704.     0.107       0.214     0.208 
##  7 M           8.53e5 M       18.6    25.1     125.  1088      0.106       0.189     0.232 
##  8 M           8.54e5 M       19.3    26.5     128.  1162      0.0940      0.172     0.166 
##  9 B           8.55e5 M       13.4    21.6      86.2  563      0.0816      0.0603    0.0311
## 10 M           8.56e5 M       13.3    20.3      87.3  545.     0.104       0.144     0.0985
## # … with 132 more rows, and 3 more variables: Concave_Points <dbl>, Symmetry <dbl>,
## #   Fractal_Dimension <dbl>

5. Compute the accuracy

Finally we can assess our classifier’s accuracy. To do this we use the metrics function from tidymodels to get the statistics about the quality of our model, specifying the truth and estimate arguments:

cancer_test_predictions %>%
  metrics(truth = Class, estimate = .pred_class)
## # A tibble: 2 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.880
## 2 kap      binary         0.741

This shows that the accuracy of the classifier on the test data was 88%. We can also look at the confusion matrix for the classifier, which shows the table of predicted labels and correct labels, using the conf_mat function:

cancer_test_predictions %>%
  conf_mat(truth = Class, estimate = .pred_class)
##           Truth
## Prediction  M  B
##          M 43  7
##          B 10 82

This says that the classifier labelled 43+82 = 125 observations correctly, 10 observations as benign when they were truly malignant, and 7 observations as malignant when they were truly benign.
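
As a quick sanity check, we can recompute the accuracy directly from these counts; it matches the value reported by the metrics function above:

# correct predictions divided by the total number of predictions
(43 + 82) / (43 + 82 + 7 + 10)
## [1] 0.8802817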

7.4 Tuning the classifier

The vast majority of predictive models in statistics and machine learning have parameters that you have to pick. For example, in the K-nearest neighbour classification algorithm we have been using in the past two chapters, we have had to pick the number of neighbours \(K\) for the class vote. Is it possible to make this selection, i.e., tune the model, in a principled way? Ideally what we want is to somehow maximize the performance of our classifier on data it hasn’t seen yet. So we will play the same trick we did before when evaluating our classifier: we’ll split our overall training data set further into two subsets, called the training set and validation set. We will use the newly-named training set for building the classifier, and the validation set for evaluating it! Then we will try different values of the parameter \(K\) and pick the one that yields the highest accuracy.

Remember: don’t touch the test set during the tuning process. Tuning is a part of model training!

7.4.1 Cross-validation

There is an important detail to mention about the process of tuning: we can, if we want to, split our overall training data up in multiple different ways, train and evaluate a classifier for each split, and then choose the parameter based on all of the different results. If we just split our overall training data once, our best parameter choice will depend strongly on whatever data was lucky enough to end up in the validation set. By using multiple different train / validation splits, we’ll get a better estimate of accuracy, which will lead to a better choice of the number of neighbours \(K\) for the overall set of training data.

Note: you might be wondering why we can’t use the multiple splits to test our final classifier after tuning is done. This is simply because at the end of the day, we will produce a single classifier using our overall training data. If we do multiple train / test splits, we will end up with multiple classifiers, each with its own accuracy evaluated on different test data.

Let’s investigate this idea in R! In particular, we will use different seed values in the set.seed function to generate five different train / validation splits of our overall training data, train five different K-nearest neighbour models, and evaluate their accuracy.

accuracies <- c()
for (i in 1:5) {
  set.seed(i) # makes the random selection of rows reproducible

  # create the 75/25 split of the training data into training and validation
  cancer_split <- initial_split(cancer_train, prop = 0.75, strata = Class)
  cancer_subtrain <- training(cancer_split)
  cancer_validation <- testing(cancer_split)

  # recreate the standardization recipe from before (since it must be based on the training data)
  cancer_recipe <- recipe(Class ~ Smoothness + Concavity, data = cancer_subtrain) %>%
    step_scale(all_predictors()) %>%
    step_center(all_predictors())

  # fit the knn model (we can reuse the old knn_spec model from before)
  knn_fit <- workflow() %>%
    add_recipe(cancer_recipe) %>%
    add_model(knn_spec) %>%
    fit(data = cancer_subtrain)

  # get predictions on the validation data
  validation_predicted <- predict(knn_fit, cancer_validation) %>%
    bind_cols(cancer_validation)

  # compute the accuracy
  acc <- validation_predicted %>%
    metrics(truth = Class, estimate = .pred_class) %>%
    filter(.metric == "accuracy") %>%
    select(.estimate) %>%
    pull()
  accuracies <- append(accuracies, acc)
}
accuracies
## [1] 0.9150943 0.8679245 0.8490566 0.8962264 0.9150943

With five different shuffles of the data, we get five different values for accuracy. None of these is necessarily “more correct” than any other; they’re just five estimates of the true, underlying accuracy of our classifier built using our overall training data. We can combine the estimates by taking their average (here 0.8886792) to try to get a single assessment of our classifier’s accuracy; this has the effect of reducing the influence of any one (un)lucky validation set on the estimate.
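
In code, this average is simply the mean of the accuracies vector that we built up in the loop above:

mean(accuracies)
## [1] 0.8886792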

In practice, we don’t use random splits, but rather a more structured splitting procedure so that each observation in the data set is used in a validation set only a single time. This strategy is called cross-validation. In cross-validation, we split our overall training data into \(C\) evenly-sized chunks, and then iteratively use \(1\) chunk as the validation set and combine the remaining \(C-1\) chunks as the training set:

Figure 7.4: 5-fold cross validation

In the picture above, \(C=5\) different chunks of the data set are used, resulting in 5 different choices for the validation set; we call this 5-fold cross-validation. To do 5-fold cross-validation in R with tidymodels, we use another function: vfold_cv. This function splits our training data into v folds automatically:

cancer_vfold <- vfold_cv(cancer_train, v = 5, strata = Class)
cancer_vfold
## #  5-fold cross-validation using stratification 
## # A tibble: 5 x 2
##   splits           id   
##   <list>           <chr>
## 1 <split [341/86]> Fold1
## 2 <split [341/86]> Fold2
## 3 <split [341/86]> Fold3
## 4 <split [342/85]> Fold4
## 5 <split [343/84]> Fold5

Then, when we create our data analysis workflow, we use the fit_resamples function instead of the fit function for training. This runs cross-validation on each train/validation split.

Note: we set the seed in the code below not only because of the potential for ties, but also because we are doing cross-validation; cross-validation uses a random process to select how to partition the training data.

set.seed(1)

# recreate the standardization recipe from before (since it must be based on the training data)
cancer_recipe <- recipe(Class ~ Smoothness + Concavity, data = cancer_train) %>%
  step_scale(all_predictors()) %>%
  step_center(all_predictors())

# fit the knn model (we can reuse the old knn_spec model from before)
knn_fit <- workflow() %>%
  add_recipe(cancer_recipe) %>%
  add_model(knn_spec) %>%
  fit_resamples(resamples = cancer_vfold)

knn_fit
## #  5-fold cross-validation using stratification 
## # A tibble: 5 x 4
##   splits           id    .metrics         .notes          
##   <list>           <chr> <list>           <list>          
## 1 <split [341/86]> Fold1 <tibble [2 × 3]> <tibble [0 × 1]>
## 2 <split [341/86]> Fold2 <tibble [2 × 3]> <tibble [0 × 1]>
## 3 <split [341/86]> Fold3 <tibble [2 × 3]> <tibble [0 × 1]>
## 4 <split [342/85]> Fold4 <tibble [2 × 3]> <tibble [0 × 1]>
## 5 <split [343/84]> Fold5 <tibble [2 × 3]> <tibble [0 × 1]>

The collect_metrics function is used to aggregate the mean and standard error of the classifier’s validation accuracy across the folds. The standard error is a measure of how uncertain we are in the mean value. A detailed treatment of this is beyond the scope of this chapter; but roughly, if your estimated mean (that the collect_metrics function gives you) is 0.88 and standard error is 0.02, you can expect the true average accuracy of the classifier to be somewhere roughly between 0.86 and 0.90 (although it may fall outside this range).

knn_fit %>% collect_metrics()
## # A tibble: 2 x 5
##   .metric  .estimator  mean     n std_err
##   <chr>    <chr>      <dbl> <int>   <dbl>
## 1 accuracy binary     0.883     5  0.0189
## 2 roc_auc  binary     0.923     5  0.0104
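
If you want to compute this rough range yourself, a minimal sketch is to add and subtract one standard error from the mean (note that this is an informal range, not a formal confidence interval):

knn_fit %>%
  collect_metrics() %>%
  filter(.metric == "accuracy") %>%
  mutate(lower = mean - std_err, upper = mean + std_err)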

We can choose any number of folds, and typically the more we use the better our accuracy estimate will be (lower standard error). However, we are limited by computational power: the more folds we choose, the more computation it takes, and hence the more time it takes to run the analysis. So when you do cross-validation, you need to consider the size of the data, the speed of the algorithm (e.g., K-nearest neighbour), and the speed of your computer. In practice, this is a trial-and-error process, but typically \(C\) is chosen to be either 5 or 10. Here we show how the standard error decreases when we use 10-fold cross-validation rather than 5-fold:

cancer_vfold <- vfold_cv(cancer_train, v = 10, strata = Class)

workflow() %>%
  add_recipe(cancer_recipe) %>%
  add_model(knn_spec) %>%
  fit_resamples(resamples = cancer_vfold) %>%
  collect_metrics()
## # A tibble: 2 x 5
##   .metric  .estimator  mean     n std_err
##   <chr>    <chr>      <dbl> <int>   <dbl>
## 1 accuracy binary     0.885    10  0.0104
## 2 roc_auc  binary     0.927    10  0.0111

7.4.2 Parameter value selection

Using 5- and 10-fold cross-validation, we have estimated that the prediction accuracy of our classifier is somewhere around 88%. Whether 88% is good or not depends entirely on the downstream application of the data analysis. In the present situation, we are trying to predict a tumour diagnosis, with expensive, damaging chemo/radiation therapy or patient death as potential consequences of misprediction. Hence, we’d like to do better than 88% for this application.

In order to improve our classifier, we have one choice of parameter: the number of neighbours, \(K\). Since cross-validation helps us evaluate the accuracy of our classifier, we can use cross-validation to calculate an accuracy for each value of \(K\) in a reasonable range, and then pick the value of \(K\) that gives us the best accuracy. The tidymodels package collection provides a very simple syntax for tuning models: each parameter in the model to be tuned should be specified as tune() in the model specification rather than given a particular value.

knn_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = tune()) %>%
  set_engine("kknn") %>%
  set_mode("classification")

Then instead of using fit or fit_resamples, we will use the tune_grid function to fit the model for each value in a range of parameter values. Here the grid = 10 argument specifies that the tuning procedure should try 10 values of the number of neighbours \(K\). We set the seed prior to tuning to ensure the results are reproducible:

set.seed(1)
knn_results <- workflow() %>%
  add_recipe(cancer_recipe) %>%
  add_model(knn_spec) %>%
  tune_grid(resamples = cancer_vfold, grid = 10) %>%
  collect_metrics()
knn_results
## # A tibble: 20 x 6
##    neighbors .metric  .estimator  mean     n std_err
##        <int> <chr>    <chr>      <dbl> <int>   <dbl>
##  1         2 accuracy binary     0.864    10 0.0167 
##  2         2 roc_auc  binary     0.897    10 0.0127 
##  3         3 accuracy binary     0.885    10 0.0104 
##  4         3 roc_auc  binary     0.927    10 0.0111 
##  5         5 accuracy binary     0.890    10 0.0136 
##  6         5 roc_auc  binary     0.927    10 0.00960
##  7         6 accuracy binary     0.890    10 0.0136 
##  8         6 roc_auc  binary     0.930    10 0.0111 
##  9         7 accuracy binary     0.890    10 0.00966
## 10         7 roc_auc  binary     0.931    10 0.0106 
## 11         9 accuracy binary     0.887    10 0.0104 
## 12         9 roc_auc  binary     0.935    10 0.0132 
## 13        10 accuracy binary     0.887    10 0.0104 
## 14        10 roc_auc  binary     0.934    10 0.0130 
## 15        12 accuracy binary     0.882    10 0.0164 
## 16        12 roc_auc  binary     0.940    10 0.0113 
## 17        13 accuracy binary     0.885    10 0.0170 
## 18        13 roc_auc  binary     0.939    10 0.0118 
## 19        15 accuracy binary     0.873    10 0.0153 
## 20        15 roc_auc  binary     0.940    10 0.0103

We can select the best value of the number of neighbours (i.e., the one that results in the highest classifier accuracy estimate) by plotting the accuracy versus \(K\):

accuracies <- knn_results %>%
  filter(.metric == "accuracy")

accuracy_vs_k <- ggplot(accuracies, aes(x = neighbors, y = mean)) +
  geom_point() +
  geom_line() +
  labs(x = "Neighbors", y = "Accuracy Estimate")
accuracy_vs_k

Figure 7.5: Plot of accuracy estimate versus number of neighbours

This visualization suggests that \(K = 7\) provides the highest accuracy. But as you can see, there is no exact or perfect answer here; any selection between \(K = 3\) and \(13\) would be reasonably justified, as all of these differ in classifier accuracy by less than 1%. Remember: the values you see on this plot are estimates of the true accuracy of our classifier. Although the \(K=7\) value is higher than the others on this plot, that doesn’t mean the classifier is actually more accurate with this parameter value! Generally, when selecting \(K\) (and other parameters for other predictive models), we are looking for a value where:

  • we get roughly optimal accuracy, so that our model will likely be accurate
  • changing the value to a nearby one (e.g. from \(K=7\) to 6 or 8) doesn’t decrease accuracy too much, so that our choice is reliable in the presence of uncertainty
  • the cost of training the model is not prohibitive (e.g., in our situation, if \(K\) is too large, predicting becomes expensive!)
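
If you do want a single starting value (keeping the caveats above in mind), you can pull out the number of neighbours with the highest accuracy estimate programmatically rather than reading it off the plot. A minimal sketch using the accuracies data frame from above:

# find the number of neighbours with the highest estimated accuracy
best_k <- accuracies %>%
  arrange(desc(mean)) %>%
  slice(1) %>%
  pull(neighbors)
best_k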

7.4.3 Under/overfitting

To build a bit more intuition, what happens if we keep increasing the number of neighbours \(K\)? In fact, the accuracy actually starts to decrease! Rather than setting grid = 10 and letting tidymodels decide what values of \(K\) to try, let’s specify the values explicitly by creating a data frame with a neighbors variable. Take a look at the plot below as we vary \(K\) from 1 to almost the number of observations in the data set:

set.seed(1)
k_lots <- tibble(neighbors = seq(from = 1, to = 385, by = 10))
knn_results <- workflow() %>%
  add_recipe(cancer_recipe) %>%
  add_model(knn_spec) %>%
  tune_grid(resamples = cancer_vfold, grid = k_lots) %>%
  collect_metrics()

accuracies <- knn_results %>%
  filter(.metric == "accuracy")

accuracy_vs_k_lots <- ggplot(accuracies, aes(x = neighbors, y = mean)) +
  geom_point() +
  geom_line() +
  labs(x = "Neighbors", y = "Accuracy Estimate")
accuracy_vs_k_lots

Figure 7.6: Plot of accuracy estimate versus number of neighbours for many K values

Underfitting: What is actually happening to our classifier that causes this? As we increase the number of neighbours, more and more of the training observations (and those that are farther and farther away from the point) get a “say” in what the class of a new observation is. This causes a sort of “averaging effect” to take place, smoothing out the boundary between where our classifier would predict a tumour to be malignant versus benign and making the model simpler. If you take this to the extreme, setting \(K\) to the total training data set size, then the classifier will always predict the same label regardless of what the new observation looks like. In general, if the model isn’t influenced enough by the training data, it is said to underfit the data.

Overfitting: In contrast, when we decrease the number of neighbours, each individual data point has a stronger and stronger vote regarding nearby points. Since the data themselves are noisy, this causes a more “jagged” boundary corresponding to a less simple model. If you take this case to the extreme, setting \(K = 1\), then the classifier is essentially just matching each new observation to its closest neighbour in the training data set. This is just as problematic as the large \(K\) case, because the classifier becomes unreliable on new data: if we had a different training set, the predictions would be completely different. In general, if the model is influenced too much by the training data, it is said to overfit the data.

You can see this effect in the plots below as we vary the number of neighbours \(K\) in (1, 7, 20, 200):

Figure 7.7: Effect of K in overfitting and underfitting

7.5 Splitting data

Shuffling: When we split the data into train, test, and validation sets, we make the assumption that there is no order to our originally collected data set. However, if we think that there might be some order to the original data set, then we can randomly shuffle the data before splitting it. The tidymodels initial_split and vfold_cv functions do this for us.

Stratification: If the data are imbalanced, we also need to be extra careful about splitting the data to ensure that enough of each class ends up in each of the train, validation, and test partitions. The strata argument in the initial_split and vfold_cv functions handles this for us too.
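
For example, a quick check (not part of the analysis above) of the class proportions in the training set shows what stratification preserves across the partitions:

# proportion of benign and malignant observations in the training data
cancer_train %>%
  group_by(Class) %>%
  summarize(n = n()) %>%
  mutate(proportion = n / sum(n))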

7.6 Summary

Classification algorithms use one or more quantitative variables to predict the value of another, categorical variable. The K-nearest neighbour algorithm in particular does this by first finding the \(K\) points in the training data nearest to the new observation, and then returning the majority class vote from those training observations. We can evaluate a classifier by splitting the data randomly into a training and test data set, using the training set to build the classifier, and using the test set to estimate its accuracy. To tune the classifier (e.g., select the \(K\) in K-nearest neighbours), we maximize accuracy estimates from cross-validation.

Figure 7.8: Overview of K-nn classification

The overall workflow for performing K-nearest neighbour classification using tidymodels is as follows:

  1. Use the initial_split function to split the data into a training and test set. Set the strata argument to the target variable. Put the test set aside for now.
  2. Use the vfold_cv function to split up the training data for cross validation.
  3. Create a recipe that specifies the target and predictor variables, as well as preprocessing steps for all variables. Pass the training data as the data argument of the recipe.
  4. Create a nearest_neighbor model specification, with neighbors = tune().
  5. Add the recipe and model specification to a workflow(), and use the tune_grid function on the train/validation splits to estimate the classifier accuracy for a range of \(K\) values.
  6. Pick a value of \(K\) that yields a high accuracy estimate that doesn’t change much if you change \(K\) to a nearby value.
  7. Make a new model specification for the best parameter value, and retrain the classifier using the fit function.
  8. Evaluate the estimated accuracy of the classifier on the test set using the predict function (a rough sketch of steps 7 and 8 follows this list).
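
To make steps 7 and 8 concrete, here is a rough sketch using the objects defined earlier in this chapter, with \(K = 7\) as the value chosen in step 6:

# make a new model specification with the chosen number of neighbours
best_knn_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = 7) %>%
  set_engine("kknn") %>%
  set_mode("classification")

# retrain the classifier on the entire training set
best_knn_fit <- workflow() %>%
  add_recipe(cancer_recipe) %>%
  add_model(best_knn_spec) %>%
  fit(data = cancer_train)

# evaluate on the test set, which we have kept aside until now
predict(best_knn_fit, cancer_test) %>%
  bind_cols(cancer_test) %>%
  metrics(truth = Class, estimate = .pred_class)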

Strengths:

  1. Simple and easy to understand
  2. No assumptions about what the data must look like
  3. Works easily for binary (two-class) and multi-class (> 2 classes) classification problems

Weaknesses:

  1. As data gets bigger and bigger, K-nearest neighbour gets slower and slower, quite quickly
  2. Does not perform well with a large number of predictors
  3. Does not perform well when classes are imbalanced (when many more observations are in one of the classes compared to the others)