Building a simple neural network using Keras and Tensorflow
Thank you
A big thank you to Leon Jessen for posting his code on GitHub. I have forked his project on GitHub and put his code into an R Notebook so we can run it in class.
Motivation
The following is a minimal example for building your first simple artificial neural network using Keras and TensorFlow for R.
Getting started - Install Keras and TensorFlow for R
You can install the Keras for R package from CRAN as follows:
# install.packages("keras3")
TensorFlow is the default backend engine. TensorFlow and Keras can be installed as follows:
# library(keras3)
# install_keras()

Naturally, we will also need Tidyverse:
# install.packages("tidyverse")

Once installed, we simply load the libraries:
library("keras3")
library("tidyverse")

── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr 1.1.4 ✔ readr 2.1.5
✔ forcats 1.0.1 ✔ stringr 1.5.2
✔ ggplot2 4.0.2 ✔ tibble 3.3.0
✔ lubridate 1.9.4 ✔ tidyr 1.3.1
✔ purrr 1.2.1
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag() masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
Artificial Neural Network Using the Iris Data Set
Right, let’s get to it!
Data
The famous (Fisher’s or Anderson’s) iris data set contains a total of 150 observations of 4 input features Sepal.Length, Sepal.Width, Petal.Length and Petal.Width and 3 output classes setosa, versicolor and virginica, with 50 observations in each class. The distributions of the feature values look like so:
iris |>
as_tibble() |>
pivot_longer(
cols = -Species,
names_to = "feature",
values_to = "value"
) |>
ggplot(aes(x = feature, y = value, fill = Species)) +
geom_violin(alpha = 0.5, scale = "width") +
theme_bw()

Our aim is to connect the 4 input features to the correct output class using an artificial neural network. For this task, we have chosen the following simple architecture with one input layer with 4 neurons (one for each feature), one hidden layer with 4 neurons and one output layer with 3 neurons (one for each class), all fully connected:
Our artificial neural network will have a total of 35 parameters: 4 for each input neuron connected to the hidden layer, plus an additional 4 for the associated first bias neuron, and 3 for each of the hidden neurons connected to the output layer, plus an additional 3 for the associated second bias neuron, i.e. \(4 \times 4 + 4 + 4 \times 3 + 3 = 35\).
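The same count can be checked directly in R: each dense layer has (inputs + 1) × units parameters, the +1 accounting for the bias.

```r
n_input <- 4; n_hidden <- 4; n_output <- 3
# weights + biases, layer by layer
params_hidden <- (n_input + 1) * n_hidden    # 4*4 weights + 4 biases = 20
params_output <- (n_hidden + 1) * n_output   # 4*3 weights + 3 biases = 15
params_hidden + params_output                # 35
```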
Prepare data
We start by wrangling the iris data set slightly, renaming and scaling the features and converting the character labels to numeric:
set.seed(265509)
nn_dat <- iris |>
as_tibble() |>
mutate(sepal_length = scale(Sepal.Length),
sepal_width = scale(Sepal.Width),
petal_length = scale(Petal.Length),
petal_width = scale(Petal.Width),
class_label = as.numeric(Species) - 1) |>
select(sepal_length, sepal_width, petal_length, petal_width, class_label)
nn_dat |> head(3)

# A tibble: 3 × 5
sepal_length[,1] sepal_width[,1] petal_length[,1] petal_width[,1] class_label
<dbl> <dbl> <dbl> <dbl> <dbl>
1 -0.898 1.02 -1.34 -1.31 0
2 -1.14 -0.132 -1.34 -1.31 0
3 -1.38 0.327 -1.39 -1.31 0
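scale() centers each feature to mean 0 and scales it to unit standard deviation, which is why the values above hover around zero. A minimal illustration:

```r
x  <- c(2, 4, 6)
sx <- as.numeric(scale(x))   # (x - mean(x)) / sd(x)
sx                           # -1 0 1  (mean was 4, sd was 2)
c(mean(sx), sd(sx))          # 0 1
```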
Then, we create indices for splitting the iris data into a training and a test data set. We set aside 20% of the data for testing:
test_fraction <- 0.20
n_total_samples <- nrow(nn_dat)
n_train_samples <- ceiling((1 - test_fraction) * n_total_samples)
train_indices <- sample(n_total_samples, n_train_samples)
n_test_samples <- n_total_samples - n_train_samples
test_indices <- setdiff(seq(1, n_total_samples), train_indices)

Based on the indices, we can now create training and test data:
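As a sanity check, the two index sets should partition the 150 rows with no overlap; a standalone sketch of the index logic (seed chosen arbitrarily for the sketch):

```r
set.seed(1)                                      # arbitrary seed for the sketch
train_indices <- sample(150, 120)                # 80% of rows for training
test_indices  <- setdiff(seq_len(150), train_indices)
length(intersect(train_indices, test_indices))   # 0: no overlap
length(train_indices) + length(test_indices)     # 150: every row used once
```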
x_train <- nn_dat |>
select(-class_label) |>
as.matrix() |>
(\(m) m[train_indices, ])()
y_train <- nn_dat |>
slice(train_indices) |>
pull(class_label) |>
to_categorical(num_classes = 3)
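to_categorical() one-hot encodes the integer labels. In base R the same encoding is essentially indexing into an identity matrix; a sketch of the idea, not the keras3 implementation:

```r
one_hot <- function(labels, n_classes) {
  diag(n_classes)[labels + 1, , drop = FALSE]  # labels are zero-based
}
one_hot(c(0, 1, 2, 0), 3)
# rows: (1,0,0), (0,1,0), (0,0,1), (1,0,0)
```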
x_test <- nn_dat |>
select(-class_label) |>
as.matrix() |>
(\(m) m[test_indices, ])()
y_test <- nn_dat |>
slice(test_indices) |>
pull(class_label) |>
to_categorical(num_classes = 3)

Set Architecture
With the data in place, we now set the architecture of our artificial neural network:
model <- keras_model_sequential()
model |>
layer_dense(units = 4, activation = 'relu', input_shape = 4) |>
layer_dense(units = 3, activation = 'softmax')
model |> summary()

Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense) │ (None, 4) │ 20 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_1 (Dense) │ (None, 3) │ 15 │
└───────────────────────────────────┴──────────────────────────┴───────────────┘
Total params: 35 (140.00 B)
Trainable params: 35 (140.00 B)
Non-trainable params: 0 (0.00 B)
Next, the architecture set in the model needs to be compiled:
model |> compile(
loss = 'categorical_crossentropy',
optimizer = optimizer_rmsprop(),
metrics = c('accuracy')
)

Train the Artificial Neural Network
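The categorical cross-entropy loss chosen above is, for a single sample, \(-\sum_i y_i \log(p_i)\); since the labels are one-hot, only the predicted probability of the true class contributes. A quick illustration with made-up numbers:

```r
y_true <- c(0, 1, 0)            # one-hot: true class is the 2nd
y_pred <- c(0.1, 0.8, 0.1)      # softmax output for one sample
-sum(y_true * log(y_pred))      # = -log(0.8), roughly 0.223
```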
Lastly, we fit the model and save the training progress in the history object:
history <- model |> fit(
x = x_train, y = y_train,
epochs = 200,
batch_size = 20,
validation_split = 0
)

Epoch 1/200
6/6 - 1s - 129ms/step - accuracy: 0.2750 - loss: 1.1042
Epoch 2/200
6/6 - 0s - 6ms/step - accuracy: 0.3167 - loss: 1.0780
Epoch 3/200
6/6 - 0s - 6ms/step - accuracy: 0.3083 - loss: 1.0594
...
Epoch 199/200
6/6 - 0s - 5ms/step - accuracy: 0.9667 - loss: 0.2690
Epoch 200/200
6/6 - 0s - 5ms/step - accuracy: 0.9667 - loss: 0.2670
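The 6/6 printed for each epoch is the number of gradient updates per epoch: the 120 training samples divided into batches of 20.

```r
n_train    <- 120   # ceiling(0.8 * 150)
batch_size <- 20
ceiling(n_train / batch_size)   # 6 batches per epoch
```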
plot(history) +
ggtitle("Training a neural network based classifier on the iris data set") +
theme_bw()

Evaluate Network Performance
The final performance can be obtained like so:
perf <- model |> evaluate(x_test, y_test)

1/1 - 0s - 293ms/step - accuracy: 0.9565 - loss: 0.3085
print(perf)

$accuracy
[1] 0.9565217
$loss
[1] 0.3084682
classes <- iris |>
as_tibble() |>
pull(Species) |>
unique()
y_pred <- model |>
predict(x_test) |>
op_argmax(axis = -1) |>
as.numeric() - 1

1/1 - 0s - 121ms/step
y_true <- nn_dat |>
slice(test_indices) |>
pull(class_label)
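op_argmax() picks, per row, the class with the highest softmax probability. In base R the equivalent (zero-based) is max.col() minus one; a sketch with made-up probabilities:

```r
probs <- matrix(c(0.80, 0.15, 0.05,
                  0.10, 0.20, 0.70), nrow = 2, byrow = TRUE)
max.col(probs) - 1   # zero-based predicted class per row: 0 2
```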
tibble(
y_true = classes[y_true + 1],
y_pred = classes[y_pred + 1],
Correct = factor(ifelse(y_true == y_pred, "Yes", "No"))
) |>
ggplot(aes(x = y_true, y = y_pred, colour = Correct)) +
geom_jitter() +
theme_bw() +
ggtitle(label = "Classification Performance of Artificial Neural Network",
subtitle = str_c("Accuracy = ",round(perf$accuracy,3)*100,"%")) +
xlab(label = "True iris class") +
ylab(label = "Predicted iris class")

library(gmodels)
CrossTable(y_pred, y_true,
prop.chisq = FALSE, prop.t = FALSE, prop.r = FALSE,
dnn = c('predicted', 'actual'))
Cell Contents
|-------------------------|
| N |
| N / Col Total |
|-------------------------|
Total Observations in Table: 23
| actual
predicted | 0 | 1 | 2 | Row Total |
-------------|-----------|-----------|-----------|-----------|
0 | 12 | 0 | 0 | 12 |
| 1.000 | 0.000 | 0.000 | |
-------------|-----------|-----------|-----------|-----------|
1 | 0 | 6 | 1 | 7 |
| 0.000 | 1.000 | 0.200 | |
-------------|-----------|-----------|-----------|-----------|
2 | 0 | 0 | 4 | 4 |
| 0.000 | 0.000 | 0.800 | |
-------------|-----------|-----------|-----------|-----------|
Column Total | 12 | 6 | 5 | 23 |
| 0.522 | 0.261 | 0.217 | |
-------------|-----------|-----------|-----------|-----------|
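Reading accuracy off the cross table: the diagonal holds the correct predictions, so accuracy is their sum over all test observations, which matches the evaluate() result above.

```r
# rows = predicted class, columns = actual class (from the table above)
confusion <- matrix(c(12, 0, 0,
                       0, 6, 1,
                       0, 0, 4),
                    nrow = 3, byrow = TRUE)
sum(diag(confusion)) / sum(confusion)   # 22/23, roughly 0.957
```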
Conclusion
I hope this illustrated just how easy it is to get started building an artificial neural network using Keras and TensorFlow in R. With relative ease, we created a 3-class predictor with an accuracy of about 96%. This was a basic minimal example. The network can be expanded to create deep learning networks, and the entire TensorFlow API is also available.
Enjoy and Happy Learning!
Leon
Thanks again Leon, this was awesome!!!