Palmer Penguins, visualization, ML Classification, kNN, Confusion Matrix, Accuracy

Welcome prospective Data Science students!

This is a quick tutorial on how to visualize the Palmer Penguins dataset, and then use kNN to classify the species of penguins. We will then use a confusion matrix to evaluate the accuracy of our model.

QR Code for this page

library(qrcode)
plot(qr_code("https://rpubs.com/esuess/kNN"))
Load the R libraries we will be using.
library(palmerpenguins)
library(DT)
library(gt)
library(naniar)
# library(devtools)
# devtools::install_github("cmartin/ggConvexHull")
library(ggConvexHull)
library(tidyverse)
library(plotly)
library(tidymodels)
library(yardstick)
Load the data
We drop the two categorical variables, island and sex. We will use the species variable as our response variable.
data(penguins)
datatable(penguins)
penguins <- penguins |> select(-c("island","sex"))
datatable(penguins)
How many penguins are there?
penguins |> select(species) |>
  group_by(species) |>
  count() |>
  pivot_wider(names_from = species, values_from = n) |>
  gt()
| Adelie | Chinstrap | Gentoo |
|---|---|---|
| 152 | 68 | 124 |
How many missing values are there?
vis_miss(penguins)
n_var_miss(penguins)
[1] 4
gg_miss_var(penguins)
library(skimr)
skim(penguins)
|  |  |
|---|---|
| Name | penguins |
| Number of rows | 344 |
| Number of columns | 5 |
| Column type frequency: |  |
| factor | 1 |
| numeric | 4 |
| Group variables | None |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| species | 0 | 1 | FALSE | 3 | Ade: 152, Gen: 124, Chi: 68 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| bill_length_mm | 2 | 0.99 | 43.92 | 5.46 | 32.1 | 39.23 | 44.45 | 48.5 | 59.6 | ▃▇▇▆▁ |
| bill_depth_mm | 2 | 0.99 | 17.15 | 1.97 | 13.1 | 15.60 | 17.30 | 18.7 | 21.5 | ▅▅▇▇▂ |
| flipper_length_mm | 2 | 0.99 | 200.92 | 14.06 | 172.0 | 190.00 | 197.00 | 213.0 | 231.0 | ▂▇▃▅▂ |
| body_mass_g | 2 | 0.99 | 4201.75 | 801.95 | 2700.0 | 3550.00 | 4050.00 | 4750.0 | 6300.0 | ▃▇▆▃▂ |
Drop the missing values
We will be using the kNN algorithm, so we need to remove the rows of data with missing values.
penguins <- penguins |> drop_na()
datatable(penguins)
Visualize the data
penguins |> ggplot(aes(x = bill_length_mm, y = bill_depth_mm)) + geom_point()
Which species is this penguin?
This is a ...? (See Wikipedia for the answer.)
penguins |> ggplot(aes(x = bill_length_mm, y = bill_depth_mm, color = species)) + geom_point()
penguins |> ggplot(aes(x = bill_length_mm, y = bill_depth_mm, color = species)) + geom_point() + facet_wrap(~species)
peng_convex <- penguins |> ggplot(aes(x = bill_length_mm, y = bill_depth_mm)) +
  geom_point() +
  geom_convexhull(alpha = 0.3, aes(fill = species))

peng_convex

peng_convex |> ggplotly()
Split the data into training and testing sets
When applying machine learning, we start by randomly splitting the data into a training set and a testing set. We use the training set to build our model, and the testing set to evaluate the accuracy of our model.
set.seed(123)
penguin_split <- initial_split(penguins, prop = 0.8, strata = species)
penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)
datatable(penguin_train)
datatable(penguin_test)
Build a kNN model for Classification
The \(k\) nearest neighbor model (kNN) classifies a new observation by finding the \(k\) closest observations in the training set and taking a majority vote of their classes. kNN is a non-parametric model: it does not assume a particular distribution for the data. It is also a lazy learner: it does not build a model in advance, but simply stores the training data and uses it directly to classify new observations. Because it is so simple, kNN is often used as a baseline model to compare to more complex models.
The kNN model measures distances. How do we measure distance in one dimension? We use the absolute value.

\[d(x,y) = |x-y|\]

How do we measure distance in two dimensions? We use the Euclidean distance.
\[d(x,y) = \sqrt{(x_1-y_1)^2 + (x_2-y_2)^2}\]
How do we measure distance in \(p\) dimensions?
\[d(x,y) = \sqrt{(x_1-y_1)^2 + (x_2-y_2)^2 + \cdots + (x_p-y_p)^2}\]
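To make the distance-and-vote idea concrete, here is a minimal from-scratch sketch in base R. It is for illustration only; the hypothetical new penguin's measurements and the choice of k = 4 are made up for the example, and below we fit the real model with the kknn engine through tidymodels.

# Illustration only: classify one hypothetical penguin by hand.
X <- penguins |> select(bill_length_mm, bill_depth_mm, flipper_length_mm, body_mass_g)
y <- penguins$species

# A hypothetical new penguin (measurement values made up for the example).
new_point <- c(45, 15, 210, 4500)

# Euclidean distance from the new point to every training observation.
dists <- sqrt(rowSums(sweep(as.matrix(X), 2, new_point)^2))

# Majority vote among the k = 4 nearest neighbors.
k <- 4
nearest <- order(dists)[1:k]
table(y[nearest])                    # vote counts
names(which.max(table(y[nearest])))  # predicted species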
The kNN model
We use the kNN model for classification with the training data.
# Using the "rectangular" weight function is the same as unweighted kNN
knn_model <- nearest_neighbor(weight_func = "rectangular", neighbors = 4) |>
  set_mode("classification") |>
  set_engine("kknn") |>
  fit(species ~ ., data = penguin_train)

knn_model
parsnip model object
Call:
kknn::train.kknn(formula = species ~ ., data = data, ks = min_rows(4, data, 5), kernel = ~"rectangular")
Type of response variable: nominal
Minimal misclassification: 0.01470588
Best kernel: rectangular
Best k: 4
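As a quick sanity check, the fitted model can classify a single new observation; the measurements below are hypothetical, chosen only for illustration.

# Hypothetical penguin measurements (illustration only).
new_penguin <- tibble(
  bill_length_mm    = 45,
  bill_depth_mm     = 15,
  flipper_length_mm = 210,
  body_mass_g       = 4500
)

predict(knn_model, new_data = new_penguin)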
Make predictions on the training set
knn_predictions <- predict(knn_model, penguin_train) |>
  bind_cols(penguin_train)
Confusion matrix: count the number of correctly classified penguins on the training set
conf_m <- conf_mat(knn_predictions, truth = species, estimate = .pred_class)
conf_m
Truth
Prediction Adelie Chinstrap Gentoo
Adelie 119 2 0
Chinstrap 1 52 0
Gentoo 0 0 98
autoplot(conf_m, type = "heatmap")
Accuracy of the model on the training data
accuracy(knn_predictions, truth = species, estimate = .pred_class)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy multiclass 0.989
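Accuracy is simply the proportion of correct predictions: the sum of the confusion matrix's diagonal divided by the total count. As a check by hand (this assumes the counts of the yardstick conf_mat object are stored in its $table element):

# Accuracy = (correctly classified) / (all penguins)
# Here: (119 + 52 + 98) / 272 ≈ 0.989, matching accuracy() above.
sum(diag(conf_m$table)) / sum(conf_m$table)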
Make predictions on the testing set
knn_predictions <- predict(knn_model, penguin_test) |>
  bind_cols(penguin_test)
Confusion matrix: count the number of correctly classified penguins on the testing set
conf_m <- conf_mat(knn_predictions, truth = species, estimate = .pred_class)
conf_m
Truth
Prediction Adelie Chinstrap Gentoo
Adelie 31 1 0
Chinstrap 0 13 0
Gentoo 0 0 25
autoplot(conf_m, type = "heatmap")
Accuracy of the model on the testing data
accuracy(knn_predictions, truth = species, estimate = .pred_class)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy multiclass 0.986
Summary
- We have loaded some data on penguins into R.
- We have visualized the data.
- We have cleaned the data.
- We have split the data into training and testing sets.
- We have neglected to scale or normalize the data (a minimal sketch follows this list; see also this second analysis that normalizes the data and uses cross-validation to tune the model to pick the best \(k\): kNN2).
- We have built a kNN model for classification.
- We have evaluated the accuracy of the model on the training set.
- We have evaluated the accuracy of the model on the testing set.
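Because kNN distances are dominated by variables on large scales (here body_mass_g, measured in grams), normalization usually matters. Here is a minimal sketch of how a normalization step could be added with a tidymodels recipe and workflow; it is an illustration under those assumptions, not the tuned analysis in the linked kNN2 page.

# Sketch: normalize predictors before fitting kNN (not run in this tutorial).
penguin_recipe <- recipe(species ~ ., data = penguin_train) |>
  step_normalize(all_numeric_predictors())   # center and scale each predictor

knn_spec <- nearest_neighbor(weight_func = "rectangular", neighbors = 4) |>
  set_mode("classification") |>
  set_engine("kknn")

knn_wflow <- workflow() |>
  add_recipe(penguin_recipe) |>
  add_model(knn_spec) |>
  fit(data = penguin_train)

# Evaluate the normalized model on the testing set.
predict(knn_wflow, penguin_test) |>
  bind_cols(penguin_test) |>
  accuracy(truth = species, estimate = .pred_class)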