
k-Nearest-Neighbors Classification in R

The following shows how to write an R script that uses the k-Nearest Neighbors method to classify whether a patient survived or died within 5 years after a breast cancer diagnosis, based on the patient's age and number of axillary nodes.

We start by setting the working directory and loading the dataset:

#loading the dataset
setwd("D:/WORK/DEEPLEARNING/SITE/Machine-Learning/03_Classification/2_k-Nearest-Neighbors/DEV")
dataset = read.csv('dataset.csv')

The ‘dataset’ is a data frame of 306 records, whose first 40 rows are:

Dataset in R of patients that survived or died within 5 years after a breast cancer diagnosis, according to their age and number of axillary nodes.
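The structure and the first rows of the loaded data frame can also be inspected directly in the R console, for instance:

# Inspect the loaded data frame
str(dataset)        # column types and number of records
head(dataset, 10)   # first 10 records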

Then we divide the dataset into a training set and a test set:

#install.packages("caTools")
library(caTools)
set.seed(123)
# 75% of the records go to the training set, 25% to the test set
split = sample.split(dataset$DEATH_WITHIN_5_YEARS, SplitRatio = 0.75)
training_set = subset(dataset, split == TRUE)
test_set = subset(dataset, split == FALSE)
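As an optional sanity check, one can verify the size of each set and that sample.split has preserved the class proportions of the DEATH_WITHIN_5_YEARS column:

# Number of records and class proportions in each set
nrow(training_set)
nrow(test_set)
prop.table(table(training_set$DEATH_WITHIN_5_YEARS))
prop.table(table(test_set$DEATH_WITHIN_5_YEARS))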

Then we can fit our k-Nearest Neighbors classifier on the training data and compute the predictions on the training and test sets:

library(class)

# Predictions on the training set (column 3 holds the class, columns 1 and 2 the features)
y_train_pred = knn(train = training_set[, -3],
                   test  = training_set[, -3],
                   cl    = training_set[, 3],
                   k     = 5)

# Predictions on the test set
y_test_pred = knn(train = training_set[, -3],
                  test  = test_set[, -3],
                  cl    = training_set[, 3],
                  k     = 5)
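Note that knn() does not return a reusable model: it classifies the records passed as ‘test’ directly. A prediction for a single new patient can therefore be obtained the same way. Here is a minimal sketch with purely hypothetical values (a 45-year-old patient with 3 axillary nodes), assuming as in the graphs below that column 1 is the age and column 2 the number of nodes:

# Hypothetical new patient: age 45, 3 axillary nodes (illustrative values only)
new_patient = training_set[1, -3]   # reuse the structure of the feature columns
new_patient[1, ] = c(45, 3)
knn(train = training_set[, -3],
    test  = new_patient,
    cl    = training_set[, 3],
    k     = 5)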

In order to measure the quality of the classification, we use a method called the "confusion matrix". Each column of the matrix represents the number of occurrences of a predicted (estimated) class, while each row represents the number of occurrences of an actual (reference) class. One advantage of the confusion matrix is that it quickly shows whether the classifier works correctly.

The following script builds the confusion matrices for the training set and for the test set:

# Rows are the actual classes, columns the predicted classes
cm_train = table(training_set[, 3], y_train_pred)
cm_test = table(test_set[, 3], y_test_pred)

print(cm_train)
print(cm_test)

The result is the following two matrices:

Confusion matrices for the training set predictions and for the test set predictions.

The sum of all the values of each matrix gives the total number of records in the test set (76 records) and in the training set (230 records) respectively.
The two rows of each confusion matrix are interpreted as follows:
– among the 56 actual test-set alive patients (resp. 169 training-set alive patients), 53 (resp. 156) are classified as such and 3 (resp. 13) are wrongly classified as dead
– among the 20 actual test-set dead patients (resp. 61 training-set dead patients), 8 (resp. 20) are classified as such and 12 (resp. 41) are wrongly classified as alive

The two columns are interpreted as follows:
– among the 65 records classified as alive in the test set (resp. 197 in the training set), 53 (resp. 156) are correct (recorded as alive in the data) and 12 (resp. 41) are actually dead
– among the 11 records classified as dead in the test set (resp. 33 in the training set), 8 (resp. 20) are correct (recorded as dead in the data) and 3 (resp. 13) are actually alive

In conclusion, the confusion matrices lead to:
– about 80% of the records successfully classified in the test set (61 correct out of 76)
– and about 77% classification success in the training set (176 correct out of 230)
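These success rates can be computed directly from the confusion matrices, since the correctly classified records lie on the diagonal:

# Overall accuracy = correctly classified records / total records
accuracy_train = sum(diag(cm_train)) / sum(cm_train)
accuracy_test  = sum(diag(cm_test))  / sum(cm_test)
print(accuracy_train)
print(accuracy_test)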

Finally, with the following script, we display in a graph the actual training set observations (dots) and the model's predictions (coloured areas) for ‘x’ years old patients that survived (green) or died (red) within 5 years after having been diagnosed with a number ‘y’ of axillary nodes:

# Visualising the Training set results
mygraph(training_set=training_set, set_to_plot=training_set, xlabel='Patient Age', xstep=0.1, ylabel='Number of Axillary Nodes', ystep=0.1, title='Training set observations (dots) and predictions (coloured areas) for \'x\' years old patients that survived (green color) or died (red color) 5 years after having been diagnosed with a number \'y\' of axillary nodes.')

This script displays the following graph:

Training set observations (dots) and predictions (coloured areas) for x years old patients that survived (green) or died (red) 5 years after having been diagnosed with a number y of axillary nodes.

A similar graph based on the test set observations can be displayed with the following script:

# Visualising the Test set results
mygraph(training_set=training_set, set_to_plot=test_set, xlabel='Patient Age', xstep=0.1, ylabel='Number of Axillary Nodes', ystep=0.1, title='Test set observations (dots) and predictions (coloured areas) for \'x\' years old patients that survived (green color) or died (red color) 5 years after having been diagnosed with a number \'y\' of axillary nodes.')

This gives the following graph:

Test set observations (dots) and predictions (coloured areas) for x years old patients that survived (green) or died (red) 5 years after having been diagnosed with a number y of axillary nodes.

We can see that the classifier works relatively well, since most of the alive patients are correctly classified as alive (green) and a non-negligible share of the dead patients are correctly classified as dead.

The mygraph(…) function is not built-in; it is a custom helper function defined as follows:

mygraph <- function(training_set, set_to_plot, xlabel, xstep, ylabel, ystep, title){
  # Build a grid covering the feature space of the set to plot
  X1 = seq(min(set_to_plot[, 1]) - 1, max(set_to_plot[, 1]) + 1, by = xstep)
  X2 = seq(min(set_to_plot[, 2]) - 1, max(set_to_plot[, 2]) + 1, by = ystep)
  grid_set = expand.grid(X1, X2)
  colnames(grid_set) = c(xlabel, ylabel)
  # Classify every grid point with the k-NN model fitted on the training set
  y_grid = knn(train = training_set[, -3],
               test = grid_set,
               cl = training_set[, 3],
               k = 5)
  # Plot the observations, the decision boundary and the coloured prediction areas
  plot(set_to_plot[, -3],
       main = title,
       xlab = xlabel, ylab = ylabel,
       xlim = range(X1), ylim = range(X2))
  contour(X1, X2, matrix(as.numeric(y_grid), length(X1), length(X2)), add = TRUE)
  points(grid_set, pch = '.', col = ifelse(y_grid == 1, 'seagreen4', 'palevioletred'))
  points(set_to_plot, pch = 21, bg = ifelse(set_to_plot[, 3] == 1, 'darkgreen', 'darkred'))
}