Create your Machine Learning library from scratch with R ! (3/5) – KNN


This is the third post of the “Create your Machine Learning library from scratch with R!” series. Today, we will see how you can implement K nearest neighbors (KNN) using only the linear algebra available in R. Previously, we managed to implement PCA, and next time we will deal with SVM and decision trees.

The K-nearest neighbors (KNN) algorithm is a simple yet efficient classification and regression method. KNN assumes that an observation will be similar to its K closest neighbors. For instance, if most of the neighbors of a given point belong to a given class, it seems reasonable to assume that the point will belong to the same class.

The mathematics of KNN

Now, let’s quickly derive the mathematics used for KNN regression (they are similar for classification).

Let \mathbf{x}_1, \dots, \mathbf{x}_n be the observations of our training dataset. The points are in \mathbb{R}^{p}. We denote by y_1, \dots, y_n the values of the variable we seek to estimate; these values are known for the training dataset.
Let \mathbf{x}_{n+1} be a new point in \mathbb{R}^{p}. We do not know y_{n+1} and will estimate it using our training dataset.

Let k be a positive integer (the number of neighbors used for estimation). We want to select the k points of the training dataset which are the closest to \mathbf{x}_{n+1}. To do so, we compute the Euclidean distances d_i=||\mathbf{x}_i-\mathbf{x}_{n+1}||_{L2} for i = 1, \dots, n. From these distances, we define D_k as the k-th smallest d_i: it is the radius of the smallest ball centered on \mathbf{x}_{n+1} which contains k points of the training sample (assuming there are no ties among the distances).

An estimate \hat{y}_{n+1} of y_{n+1} is now easy to construct: it is the mean of the y_i of the k closest points to \mathbf{x}_{n+1}:

    \[\hat{y}_{n+1} = \frac{1}{k} \sum_{i\leq n} y_i 1_{d_i \leq D_k}\]
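To make the formula concrete, here is a small worked sketch on a made-up training set of five points in \mathbb{R}^{2} (the data, x_new and k = 3 are arbitrary assumptions for illustration). It computes the d_i, takes D_k as the k-th smallest distance, and averages the y_i of the points inside the ball.

```r
# Made-up training set: n = 5 points in R^2 (one per row) with known targets y_i
X = matrix(c(0, 0,
             1, 0,
             0, 2,
             3, 3,
             5, 1), ncol = 2, byrow = TRUE)
y = c(1, 2, 3, 10, 20)
x_new = c(0, 0.5)   # the new point x_{n+1}
k = 3

# d_i = ||x_i - x_{n+1}||_2 for every training point (sweep subtracts x_new from each row)
d = sqrt(rowSums(sweep(X, 2, x_new)^2))

# D_k is the k-th smallest distance
D_k = sort(d)[k]

# y_hat = (1/k) * sum_i y_i * 1{d_i <= D_k}
y_hat = sum(y * (d <= D_k)) / k
y_hat   # [1] 2
```

Here the three closest points are (0,0), (1,0) and (0,2), so D_3 = 1.5 and the estimate is (1 + 2 + 3) / 3 = 2.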

KNN regression in R

First, we build a “my_knn_regressor” object which stores all the training points, the values of the target variable and the number of neighbors to use.

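###Nearest neighbors
my_knn_regressor = function(x, y, k = 5)
{
  ##Store the training data as matrices
  if (!is.matrix(x))
    x = as.matrix(x)
  if (!is.matrix(y))
    y = as.matrix(y)
  my_knn = list()
  my_knn[['points']] = x
  my_knn[['value']] = y
  my_knn[['k']] = k
  ##Set the S3 class so that predict() dispatches to predict.my_knn_regressor
  attr(my_knn, "class") = "my_knn_regressor"
  return(my_knn)
}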

The tricky part of KNN is computing the distances efficiently. We will use the function we created in our previous post on vectorization; the function and its mathematical derivation are detailed in that post.

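##Squared Euclidean distances: entry (i,j) is ||X[i,] - Y[j,]||^2
compute_pairwise_distance = function(X, Y) {
  xn = rowSums(X ^ 2)
  yn = rowSums(Y ^ 2)
  outer(xn, yn, '+') - 2 * tcrossprod(X, Y)
}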
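This function returns squared Euclidean distances, which is enough for KNN since taking the square root does not change the ranking of the neighbors. The sketch below checks the vectorized identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b against a naive double loop on random matrices. The names vectorized_sq_dist and naive_sq_dist are illustrative only; vectorized_sq_dist restates the same formula as compute_pairwise_distance so that the snippet runs on its own.

```r
# Same formula as compute_pairwise_distance, restated under a local (hypothetical) name
vectorized_sq_dist = function(X, Y) {
  outer(rowSums(X^2), rowSums(Y^2), "+") - 2 * tcrossprod(X, Y)
}

# Naive reference: explicit double loop over all pairs of rows
naive_sq_dist = function(X, Y) {
  D = matrix(0, nrow(X), nrow(Y))
  for (i in seq_len(nrow(X))) {
    for (j in seq_len(nrow(Y))) {
      D[i, j] = sum((X[i, ] - Y[j, ])^2)
    }
  }
  D
}

set.seed(42)
X = matrix(rnorm(12), nrow = 4)   # 4 points in R^3
Y = matrix(rnorm(15), nrow = 5)   # 5 points in R^3

D_fast = vectorized_sq_dist(X, Y)
D_slow = naive_sq_dist(X, Y)
dim(D_fast)                 # 4 5: one row per point of X, one column per point of Y
all.equal(D_fast, D_slow)   # TRUE, up to floating-point error
```

The vectorized version replaces the two loops with one matrix product (tcrossprod) and one outer sum, which is what makes it fast on large datasets.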

Now we can build our predictor:

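predict.my_knn_regressor = function(my_knn, x, ...)
{
  if (!is.matrix(x))
    x = as.matrix(x)
  ##Compute pairwise distances: one row per new point, one column per training point
  dist_pair = compute_pairwise_distance(x, my_knn[['points']])
  ##For each new point (row of dist_pair), rank the training points by distance and keep the k closest:
  ##M[j,i] is TRUE if x_j is one of the k closest training points to the new point x_i
  M = apply(dist_pair, 1, function(d) rank(d, ties.method = "first") <= my_knn[['k']])
  ##Average the target values of the selected neighbors (drop() returns a plain vector for a single target)
  drop(crossprod(M, my_knn[["value"]]) / my_knn[['k']])
}

The last lines may seem complicated:

  1. apply(dist_pair, 1, function(d) rank(d, ties.method = "first")) ranks the training points by distance, separately for each new point (each row of dist_pair). We need rank() and not order() here: order() returns the indices of the training points sorted by distance, whereas rank() returns the position of each training point in that sorted list.
  2. Comparing these ranks with my_knn[['k']] selects the k closest training points to each point of our new dataset (ties.method = "first" guarantees exactly k of them).
  3. The result M is a 0/1 (indicator) matrix with one row per training point and one column per new point: \mathbf{M}_{j,i}=1 if \mathbf{x}_j is one of the k closest points to the new point \mathbf{x}_i, and \mathbf{M}_{j,i}=0 otherwise.
  4. crossprod(M, my_knn[['value']]), which equals t(M) %*% my_knn[['value']], sums the values of the k closest points to each new point, and dividing by my_knn[['k']] normalises this sum into the mean \hat{y} defined above.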
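A toy example (the distances and target values below are made up) makes the selection step concrete. It also shows why the ranks of the distances, and not the output of order(), have to be compared with k: order(d) <= k flags the positions of the sorted list that happen to hold one of the indices 1, ..., k, which has nothing to do with distance.

```r
# Made-up squared distances: 2 new points (rows) x 4 training points (columns)
dist_pair = matrix(c(5, 1, 3, 4,
                     2, 9, 8, 1), nrow = 2, byrow = TRUE)
value = matrix(c(10, 20, 30, 60))   # targets of the 4 training points
k = 2

d = dist_pair[1, ]
order(d)                          # 2 3 4 1: indices of the points sorted by distance
rank(d, ties.method = "first")    # 4 1 2 3: position of each point in that sorted list

# Indicator matrix, one column per new point: TRUE for its k nearest training points
M = apply(dist_pair, 1, function(d) rank(d, ties.method = "first") <= k)
pred_rank = crossprod(M, value) / k     # 25 and 35: means over the true k nearest neighbors

# Comparing order() with k instead selects the wrong points
pred_order = crossprod(apply(dist_pair, 1, order) <= k, value) / k   # 35 and 40
```

For the first new point, the two nearest training points are the 2nd and 3rd ones (distances 1 and 3), so the correct estimate is (20 + 30) / 2 = 25, while the order()-based selection averages the 1st and 4th points instead.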

KNN Binary Classification in R

The previous code can be reused as is for binary classification. Your outcome should be encoded as a 0/1 variable: the prediction is then the fraction of the k neighbors that belong to the class encoded as one. If the estimated output is greater (resp. less) than 0.5, you can assume that your point belongs to the class encoded as one (resp. zero). We will use the classical Iris dataset and classify the setosa versus the virginica species.

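##Keep only setosa and virginica
iris_class = iris[iris[["Species"]] != "versicolor", ]
##Encode the outcome: TRUE (1) for virginica, FALSE (0) for setosa
iris_class[["Species"]] = iris_class[["Species"]] != "setosa"
##Fit on sepal length and sepal width only, with the default k = 5
knn_class = my_knn_regressor(iris_class[, 1:2], as.numeric(iris_class[, 5]))
predict(knn_class, iris_class[, 1:2])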
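Since the prediction is the fraction of neighbors in the class encoded as one, thresholding it at 0.5 amounts to a majority vote among the k neighbors (with an odd k such as the default 5, a tie at exactly 0.5 cannot happen). The self-contained sketch below illustrates this on made-up neighbor labels rather than on the Iris output; in practice, prob would be the output of predict(knn_class, ...) and true_class the 0/1 labels.

```r
# Made-up 0/1 labels of the k = 5 nearest neighbors of 3 new points (one column per point)
neighbor_labels = matrix(c(1, 1, 0, 1, 0,
                           0, 0, 1, 0, 0,
                           1, 1, 1, 1, 1), nrow = 5)
# Averaging the 0/1 labels gives the fraction of neighbors in class one
prob = colMeans(neighbor_labels)            # 0.6 0.2 1.0
# Thresholding at 0.5 is a majority vote among the neighbors
predicted_class = as.numeric(prob > 0.5)    # 1 0 1
# Accuracy against made-up true labels
true_class = c(1, 0, 0)
accuracy = mean(predicted_class == true_class)
accuracy                                    # 0.6666667
```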

Since we only used 2 variables, we can easily plot the decision boundaries on a 2D plot.

library(ggplot2)

#Build a 200 x 200 grid covering the two features
x_coord = seq(min(iris_class[,1]) - 0.2, max(iris_class[,1]) + 0.2, length.out = 200)
y_coord = seq(min(iris_class[,2]) - 0.2, max(iris_class[,2]) + 0.2, length.out = 200)
coord = expand.grid(x = x_coord, y = y_coord)
#predict probabilities
coord[['prob']] = predict(knn_class, coord[,1:2])

ggplot() +
  ##Add tiles according to probabilities
  geom_tile(data = coord, mapping = aes(x, y, fill = prob)) +
  scale_fill_gradient(low = "lightblue", high = "red") +
  ##add points
  geom_point(data = iris_class, mapping = aes(Sepal.Length, Sepal.Width, shape = Species), size = 3) +
  #add the labels to the plots
  xlab('Sepal length') + ylab('Sepal width') + ggtitle('Decision boundaries of KNN') +
  #remove grey border from the tile
  scale_x_continuous(expand = c(0, 0)) + scale_y_continuous(expand = c(0, 0))

And this gives us this cool plot:

Possible extensions

Our current KNN is basic, but you can improve and test it in several ways:

  • What is the influence of the number of neighbors? (You should see some overfitting/underfitting.)
  • Can you implement metrics other than the L_2 distance? Can you create kernel KNNs?
  • Instead of doing estimations using only the mean, could you use a more complex mapping?
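As a starting point for the second question, here is a minimal sketch of a Manhattan (L_1) distance with the same interface as compute_pairwise_distance; the name compute_pairwise_l1_distance is hypothetical. Since the predictor only uses the ranking of the distances, it should be possible to swap it in for compute_pairwise_distance inside predict.my_knn_regressor to obtain an L_1 KNN.

```r
# Pairwise L1 (Manhattan) distances between the rows of X and the rows of Y.
# Same output shape as compute_pairwise_distance: nrow(X) x nrow(Y).
compute_pairwise_l1_distance = function(X, Y) {
  X = as.matrix(X)
  Y = as.matrix(Y)
  D = matrix(0, nrow(X), nrow(Y))
  # Loop over the p features only; each step is a vectorized outer difference
  for (j in seq_len(ncol(X))) {
    D = D + abs(outer(X[, j], Y[, j], "-"))
  }
  D
}

X = matrix(c(0, 0,
             1, 2), ncol = 2, byrow = TRUE)
Y = matrix(c(3, 1), ncol = 2)
compute_pairwise_l1_distance(X, Y)   # 2 x 1 matrix: 4 and 3
```

Unlike the L_2 case, there is no matrix-product shortcut for the L_1 distance, so this version loops over the p features; this stays cheap as long as p is small.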

Thanks for reading! To find more posts on Machine Learning, Python and R, you can follow us on Facebook or Twitter.



