Building a simple neural network using Keras and TensorFlow
Thank you
A big thank you to Leon Jessen for posting his code on GitHub.
I have forked his project on GitHub and put his code into an R Notebook so we can run it in class.
Motivation
The following is a minimal example for building your first simple artificial neural network using Keras and TensorFlow for R.
Getting started - Install Keras and TensorFlow for R
You can install the Keras for R package from CRAN as follows:
# install.packages("keras3")
TensorFlow is the default backend engine. TensorFlow and Keras can be installed as follows:
# library(keras3)
# install_keras(backend = "tensorflow")
Naturally, we will also need Tidyverse:
# install.packages("tidyverse")
Once installed, we simply load the libraries:
library("keras3")
library("tidyverse")
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr 1.1.4 ✔ readr 2.1.5
✔ forcats 1.0.1 ✔ stringr 1.5.2
✔ ggplot2 4.0.2 ✔ tibble 3.3.0
✔ lubridate 1.9.4 ✔ tidyr 1.3.1
✔ purrr 1.2.1
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag() masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
Artificial Neural Network Using the Iris Data Set
Right, let’s get to it!
Data
The famous (Fisher’s or Anderson’s) iris data set contains a total of 150 observations of 4 input features (Sepal.Length, Sepal.Width, Petal.Length and Petal.Width) and 3 output classes (setosa, versicolor and virginica), with 50 observations in each class. The distributions of the feature values look like so:
iris |>
as_tibble() |>
pivot_longer(
cols = -Species,
names_to = "feature",
values_to = "value"
) |>
ggplot(aes(x = feature, y = value, fill = Species)) +
geom_violin(alpha = 0.5, scale = "width") +
theme_bw()
Our aim is to connect the 4 input features to the correct output class using an artificial neural network. For this task, we have chosen the following simple architecture with one input layer with 4 neurons (one for each feature), one hidden layer with 4 neurons and one output layer with 3 neurons (one for each class), all fully connected:
Our artificial neural network will have a total of 35 parameters: 4 weights for each of the 4 input neurons connected to the hidden layer (16), plus an additional 4 for the associated first bias neuron, and 3 weights for each of the 4 hidden neurons connected to the output layer (12), plus an additional 3 for the associated second bias neuron, i.e. \(4 \times 4 + 4 + 4 \times 3 + 3 = 35\).
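The parameter count above can be checked with a few lines of base R. This is only a sketch of the arithmetic; the helper name dense_params is made up for illustration and is not part of Keras.

```r
# Layer sizes taken from the architecture described above
n_input  <- 4  # one neuron per feature
n_hidden <- 4  # width of the hidden layer
n_output <- 3  # one neuron per class

# A fully connected layer has (inputs x units) weights plus one bias per unit
# (dense_params is a hypothetical helper, not a Keras function)
dense_params <- function(n_in, n_units) n_in * n_units + n_units

hidden_params <- dense_params(n_input, n_hidden)   # 4 * 4 + 4
output_params <- dense_params(n_hidden, n_output)  # 4 * 3 + 3
total_params  <- hidden_params + output_params

cat("hidden:", hidden_params, "output:", output_params, "total:", total_params, "\n")
```

The two per-layer counts should agree with the Param # column that Keras prints in the model summary.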
Prepare data
We start by lightly wrangling the iris data set: renaming and scaling the features and converting the class labels to numeric codes starting at 0:
set.seed(265509)
nn_dat <- iris |>
as_tibble() |>
mutate(sepal_length = scale(Sepal.Length),
sepal_width = scale(Sepal.Width),
petal_length = scale(Petal.Length),
petal_width = scale(Petal.Width),
class_label = as.numeric(Species) - 1) |>
select(sepal_length, sepal_width, petal_length, petal_width, class_label)
nn_dat |> head(3)
# A tibble: 3 × 5
sepal_length[,1] sepal_width[,1] petal_length[,1] petal_width[,1] class_label
<dbl> <dbl> <dbl> <dbl> <dbl>
1 -0.898 1.02 -1.34 -1.31 0
2 -1.14 -0.132 -1.34 -1.31 0
3 -1.38 0.327 -1.39 -1.31 0
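The "[,1]" suffix in the column headers above appears because scale() returns a one-column matrix rather than a plain vector. The following base-R sketch, which uses only the built-in iris data, shows what scale() computes and how the result could be flattened with as.numeric() if a plain numeric column is preferred.

```r
# scale() centres a column on its mean and divides by its standard deviation
x      <- iris$Sepal.Length
manual <- (x - mean(x)) / sd(x)
scaled <- scale(x)

# scale() returns a matrix, which is why the tibble header shows "[,1]"
print(is.matrix(scaled))

# The manual z-score and scale() agree once the matrix is flattened
print(all.equal(as.numeric(scaled), manual))

# as.numeric() yields a plain vector with mean 0 and standard deviation 1
z <- as.numeric(scaled)
print(head(round(z, 3), 3))
```

The model fitting further down works with either form, since the features are converted with as.matrix() before training.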
Then, we create indices for splitting the iris data into a training and a test data set. We set aside 20% of the data for testing:
test_fraction <- 0.20
n_total_samples <- nrow(nn_dat)
n_train_samples <- ceiling((1 - test_fraction) * n_total_samples)
train_indices <- sample(n_total_samples, n_train_samples)
n_test_samples <- n_total_samples - n_train_samples
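# Test rows are all rows not sampled for training. The range must span all
# n_total_samples rows: seq(1, n_train_samples) would ignore rows 121-150,
# which is why the recorded output further down shows only 23 test rows.
test_indices <- setdiff(seq_len(n_total_samples), train_indices)
Based on the indices, we can now create training and test data: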
x_train <- nn_dat |>
select(-class_label) |>
as.matrix() |>
(\(m) m[train_indices, ])()
y_train <- nn_dat |>
slice(train_indices) |>
pull(class_label) |>
to_categorical(num_classes = 3)
x_test <- nn_dat |>
select(-class_label) |>
as.matrix() |>
(\(m) m[test_indices, ])()
y_test <- nn_dat |>
slice(test_indices) |>
pull(class_label) |>
to_categorical(num_classes = 3)
Set Architecture
With the data in place, we now set the architecture of our artificial neural network:
model <- keras_model_sequential()
model |>
layer_dense(units = 4, activation = 'relu', input_shape = 4) |>
layer_dense(units = 3, activation = 'softmax')
model |> summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense) │ (None, 4) │ 20 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_1 (Dense) │ (None, 3) │ 15 │
└───────────────────────────────────┴──────────────────────────┴───────────────┘
Total params: 35 (140.00 B)
Trainable params: 35 (140.00 B)
Non-trainable params: 0 (0.00 B)
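To make the summary more concrete, here is a rough base-R sketch of the forward pass such a 4-4-3 network performs: a ReLU hidden layer followed by a softmax output layer. The weights are random placeholders rather than the trained values, and the names relu, softmax, forward, W1, b1, W2 and b2 are chosen for this illustration only.

```r
set.seed(1)  # arbitrary seed, only to make the sketch reproducible

relu    <- function(z) pmax(z, 0)
softmax <- function(z) {
  e <- exp(z - max(z))  # subtract the max for numerical stability
  e / sum(e)
}

# Randomly initialised parameters (placeholders, not the fitted weights)
W1 <- matrix(rnorm(4 * 4), nrow = 4)  # input -> hidden weights (16)
b1 <- rep(0, 4)                       # hidden biases (4)
W2 <- matrix(rnorm(4 * 3), nrow = 4)  # hidden -> output weights (12)
b2 <- rep(0, 3)                       # output biases (3)

# Forward pass for a single observation with 4 scaled features
forward <- function(x) {
  h <- relu(as.numeric(x %*% W1) + b1)
  softmax(as.numeric(h %*% W2) + b2)
}

x <- c(-0.898, 1.02, -1.34, -1.31)  # first row of the scaled data shown earlier
p <- forward(x)
n_params <- length(W1) + length(b1) + length(W2) + length(b2)

print(p)         # three class probabilities
print(sum(p))    # softmax output sums to 1
print(n_params)  # same count as in the summary table
```

Training adjusts the entries of these four arrays so that the largest probability falls on the correct class.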
Next, the architecture set in the model needs to be compiled:
model |> compile(
loss = 'categorical_crossentropy',
optimizer = optimizer_rmsprop(),
metrics = c('accuracy')
)
Train the Artificial Neural Network
Lastly, we fit the model and save the training progress in the history object:
history <- model |> fit(
x = x_train, y = y_train,
epochs = 200,
batch_size = 20,
validation_split = 0
)
Epoch 1/200
6/6 - 1s - 142ms/step - accuracy: 0.5250 - loss: 0.9912
Epoch 2/200
6/6 - 0s - 7ms/step - accuracy: 0.5333 - loss: 0.9721
Epoch 3/200
6/6 - 0s - 6ms/step - accuracy: 0.5500 - loss: 0.9581
Epoch 4/200
6/6 - 0s - 6ms/step - accuracy: 0.5583 - loss: 0.9455
Epoch 5/200
6/6 - 0s - 6ms/step - accuracy: 0.5833 - loss: 0.9330
Epoch 6/200
6/6 - 0s - 6ms/step - accuracy: 0.6000 - loss: 0.9208
Epoch 7/200
6/6 - 0s - 6ms/step - accuracy: 0.6083 - loss: 0.9085
Epoch 8/200
6/6 - 0s - 6ms/step - accuracy: 0.6083 - loss: 0.8969
Epoch 9/200
6/6 - 0s - 6ms/step - accuracy: 0.6083 - loss: 0.8851
Epoch 10/200
6/6 - 0s - 6ms/step - accuracy: 0.6083 - loss: 0.8736
Epoch 11/200
6/6 - 0s - 6ms/step - accuracy: 0.6333 - loss: 0.8627
Epoch 12/200
6/6 - 0s - 6ms/step - accuracy: 0.6167 - loss: 0.8517
Epoch 13/200
6/6 - 0s - 5ms/step - accuracy: 0.6167 - loss: 0.8406
Epoch 14/200
6/6 - 0s - 5ms/step - accuracy: 0.6167 - loss: 0.8298
Epoch 15/200
6/6 - 0s - 5ms/step - accuracy: 0.6333 - loss: 0.8189
Epoch 16/200
6/6 - 0s - 6ms/step - accuracy: 0.6333 - loss: 0.8078
Epoch 17/200
6/6 - 0s - 6ms/step - accuracy: 0.6250 - loss: 0.7969
Epoch 18/200
6/6 - 0s - 6ms/step - accuracy: 0.6250 - loss: 0.7858
Epoch 19/200
6/6 - 0s - 5ms/step - accuracy: 0.6417 - loss: 0.7747
Epoch 20/200
6/6 - 0s - 5ms/step - accuracy: 0.6500 - loss: 0.7644
Epoch 21/200
6/6 - 0s - 5ms/step - accuracy: 0.6417 - loss: 0.7538
Epoch 22/200
6/6 - 0s - 5ms/step - accuracy: 0.6500 - loss: 0.7434
Epoch 23/200
6/6 - 0s - 5ms/step - accuracy: 0.6583 - loss: 0.7331
Epoch 24/200
6/6 - 0s - 7ms/step - accuracy: 0.6750 - loss: 0.7232
Epoch 25/200
6/6 - 0s - 6ms/step - accuracy: 0.6667 - loss: 0.7131
Epoch 26/200
6/6 - 0s - 6ms/step - accuracy: 0.6750 - loss: 0.7031
Epoch 27/200
6/6 - 0s - 6ms/step - accuracy: 0.6750 - loss: 0.6934
Epoch 28/200
6/6 - 0s - 6ms/step - accuracy: 0.6750 - loss: 0.6837
Epoch 29/200
6/6 - 0s - 5ms/step - accuracy: 0.6833 - loss: 0.6744
Epoch 30/200
6/6 - 0s - 5ms/step - accuracy: 0.6750 - loss: 0.6651
Epoch 31/200
6/6 - 0s - 5ms/step - accuracy: 0.6833 - loss: 0.6565
Epoch 32/200
6/6 - 0s - 5ms/step - accuracy: 0.6917 - loss: 0.6475
Epoch 33/200
6/6 - 0s - 5ms/step - accuracy: 0.7000 - loss: 0.6390
Epoch 34/200
6/6 - 0s - 5ms/step - accuracy: 0.7000 - loss: 0.6302
Epoch 35/200
6/6 - 0s - 5ms/step - accuracy: 0.7000 - loss: 0.6219
Epoch 36/200
6/6 - 0s - 5ms/step - accuracy: 0.7167 - loss: 0.6136
Epoch 37/200
6/6 - 0s - 5ms/step - accuracy: 0.7333 - loss: 0.6055
Epoch 38/200
6/6 - 0s - 6ms/step - accuracy: 0.7417 - loss: 0.5977
Epoch 39/200
6/6 - 0s - 7ms/step - accuracy: 0.7583 - loss: 0.5898
Epoch 40/200
6/6 - 0s - 6ms/step - accuracy: 0.7750 - loss: 0.5822
Epoch 41/200
6/6 - 0s - 6ms/step - accuracy: 0.7750 - loss: 0.5748
Epoch 42/200
6/6 - 0s - 5ms/step - accuracy: 0.7750 - loss: 0.5676
Epoch 43/200
6/6 - 0s - 5ms/step - accuracy: 0.7833 - loss: 0.5607
Epoch 44/200
6/6 - 0s - 6ms/step - accuracy: 0.7833 - loss: 0.5539
Epoch 45/200
6/6 - 0s - 6ms/step - accuracy: 0.7833 - loss: 0.5473
Epoch 46/200
6/6 - 0s - 6ms/step - accuracy: 0.7833 - loss: 0.5408
Epoch 47/200
6/6 - 0s - 5ms/step - accuracy: 0.7833 - loss: 0.5345
Epoch 48/200
6/6 - 0s - 5ms/step - accuracy: 0.7917 - loss: 0.5285
Epoch 49/200
6/6 - 0s - 6ms/step - accuracy: 0.8000 - loss: 0.5225
Epoch 50/200
6/6 - 0s - 6ms/step - accuracy: 0.7917 - loss: 0.5168
Epoch 51/200
6/6 - 0s - 5ms/step - accuracy: 0.8000 - loss: 0.5110
Epoch 52/200
6/6 - 0s - 5ms/step - accuracy: 0.8000 - loss: 0.5057
Epoch 53/200
6/6 - 0s - 5ms/step - accuracy: 0.8083 - loss: 0.5005
Epoch 54/200
6/6 - 0s - 5ms/step - accuracy: 0.8083 - loss: 0.4953
Epoch 55/200
6/6 - 0s - 6ms/step - accuracy: 0.8083 - loss: 0.4904
Epoch 56/200
6/6 - 0s - 5ms/step - accuracy: 0.8167 - loss: 0.4860
Epoch 57/200
6/6 - 0s - 5ms/step - accuracy: 0.8167 - loss: 0.4813
Epoch 58/200
6/6 - 0s - 5ms/step - accuracy: 0.8167 - loss: 0.4768
Epoch 59/200
6/6 - 0s - 5ms/step - accuracy: 0.8167 - loss: 0.4725
Epoch 60/200
6/6 - 0s - 5ms/step - accuracy: 0.8167 - loss: 0.4680
Epoch 61/200
6/6 - 0s - 5ms/step - accuracy: 0.8250 - loss: 0.4637
Epoch 62/200
6/6 - 0s - 5ms/step - accuracy: 0.8250 - loss: 0.4595
Epoch 63/200
6/6 - 0s - 5ms/step - accuracy: 0.8250 - loss: 0.4550
Epoch 64/200
6/6 - 0s - 6ms/step - accuracy: 0.8250 - loss: 0.4511
Epoch 65/200
6/6 - 0s - 6ms/step - accuracy: 0.8250 - loss: 0.4469
Epoch 66/200
6/6 - 0s - 5ms/step - accuracy: 0.8250 - loss: 0.4428
Epoch 67/200
6/6 - 0s - 5ms/step - accuracy: 0.8333 - loss: 0.4389
Epoch 68/200
6/6 - 0s - 5ms/step - accuracy: 0.8500 - loss: 0.4352
Epoch 69/200
6/6 - 0s - 5ms/step - accuracy: 0.8417 - loss: 0.4313
Epoch 70/200
6/6 - 0s - 5ms/step - accuracy: 0.8500 - loss: 0.4280
Epoch 71/200
6/6 - 0s - 6ms/step - accuracy: 0.8583 - loss: 0.4246
Epoch 72/200
6/6 - 0s - 5ms/step - accuracy: 0.8583 - loss: 0.4209
Epoch 73/200
6/6 - 0s - 6ms/step - accuracy: 0.8583 - loss: 0.4178
Epoch 74/200
6/6 - 0s - 5ms/step - accuracy: 0.8583 - loss: 0.4146
Epoch 75/200
6/6 - 0s - 5ms/step - accuracy: 0.8583 - loss: 0.4117
Epoch 76/200
6/6 - 0s - 5ms/step - accuracy: 0.8583 - loss: 0.4084
Epoch 77/200
6/6 - 0s - 5ms/step - accuracy: 0.8667 - loss: 0.4056
Epoch 78/200
6/6 - 0s - 5ms/step - accuracy: 0.8750 - loss: 0.4024
Epoch 79/200
6/6 - 0s - 5ms/step - accuracy: 0.8667 - loss: 0.3996
Epoch 80/200
6/6 - 0s - 5ms/step - accuracy: 0.8750 - loss: 0.3969
Epoch 81/200
6/6 - 0s - 6ms/step - accuracy: 0.8750 - loss: 0.3938
Epoch 82/200
6/6 - 0s - 5ms/step - accuracy: 0.8750 - loss: 0.3909
Epoch 83/200
6/6 - 0s - 6ms/step - accuracy: 0.8833 - loss: 0.3880
Epoch 84/200
6/6 - 0s - 5ms/step - accuracy: 0.8833 - loss: 0.3849
Epoch 85/200
6/6 - 0s - 5ms/step - accuracy: 0.8833 - loss: 0.3819
Epoch 86/200
6/6 - 0s - 5ms/step - accuracy: 0.8917 - loss: 0.3790
Epoch 87/200
6/6 - 0s - 5ms/step - accuracy: 0.8917 - loss: 0.3759
Epoch 88/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3732
Epoch 89/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3703
Epoch 90/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3673
Epoch 91/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3648
Epoch 92/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3618
Epoch 93/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3596
Epoch 94/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3568
Epoch 95/200
6/6 - 0s - 6ms/step - accuracy: 0.9083 - loss: 0.3545
Epoch 96/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3524
Epoch 97/200
6/6 - 0s - 6ms/step - accuracy: 0.9083 - loss: 0.3499
Epoch 98/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3478
Epoch 99/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3454
Epoch 100/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3430
Epoch 101/200
6/6 - 0s - 6ms/step - accuracy: 0.9083 - loss: 0.3410
Epoch 102/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3387
Epoch 103/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3365
Epoch 104/200
6/6 - 0s - 6ms/step - accuracy: 0.8917 - loss: 0.3343
Epoch 105/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3323
Epoch 106/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3300
Epoch 107/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3279
Epoch 108/200
6/6 - 0s - 5ms/step - accuracy: 0.8917 - loss: 0.3258
Epoch 109/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3239
Epoch 110/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3216
Epoch 111/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3197
Epoch 112/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3175
Epoch 113/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3154
Epoch 114/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3137
Epoch 115/200
6/6 - 0s - 7ms/step - accuracy: 0.9000 - loss: 0.3115
Epoch 116/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.3098
Epoch 117/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3080
Epoch 118/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3060
Epoch 119/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3040
Epoch 120/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3020
Epoch 121/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.3000
Epoch 122/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.2981
Epoch 123/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2964
Epoch 124/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2945
Epoch 125/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2925
Epoch 126/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2909
Epoch 127/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2891
Epoch 128/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2872
Epoch 129/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.2856
Epoch 130/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2839
Epoch 131/200
6/6 - 0s - 6ms/step - accuracy: 0.9083 - loss: 0.2823
Epoch 132/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2804
Epoch 133/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2789
Epoch 134/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2770
Epoch 135/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.2753
Epoch 136/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.2734
Epoch 137/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.2716
Epoch 138/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2699
Epoch 139/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.2684
Epoch 140/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2665
Epoch 141/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2652
Epoch 142/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2635
Epoch 143/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2619
Epoch 144/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2605
Epoch 145/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2589
Epoch 146/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2573
Epoch 147/200
6/6 - 0s - 6ms/step - accuracy: 0.9250 - loss: 0.2561
Epoch 148/200
6/6 - 0s - 7ms/step - accuracy: 0.9250 - loss: 0.2545
Epoch 149/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2531
Epoch 150/200
6/6 - 0s - 6ms/step - accuracy: 0.9250 - loss: 0.2515
Epoch 151/200
6/6 - 0s - 6ms/step - accuracy: 0.9250 - loss: 0.2502
Epoch 152/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2486
Epoch 153/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2470
Epoch 154/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2458
Epoch 155/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2444
Epoch 156/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2428
Epoch 157/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2414
Epoch 158/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2403
Epoch 159/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2387
Epoch 160/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2374
Epoch 161/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2362
Epoch 162/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2350
Epoch 163/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2334
Epoch 164/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2324
Epoch 165/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2312
Epoch 166/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2302
Epoch 167/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2288
Epoch 168/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2278
Epoch 169/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2267
Epoch 170/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2255
Epoch 171/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2246
Epoch 172/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2230
Epoch 173/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2219
Epoch 174/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2206
Epoch 175/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2196
Epoch 176/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2180
Epoch 177/200
6/6 - 0s - 7ms/step - accuracy: 0.9417 - loss: 0.2168
Epoch 178/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2156
Epoch 179/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2140
Epoch 180/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2132
Epoch 181/200
6/6 - 0s - 7ms/step - accuracy: 0.9417 - loss: 0.2117
Epoch 182/200
6/6 - 0s - 7ms/step - accuracy: 0.9417 - loss: 0.2103
Epoch 183/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2091
Epoch 184/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2075
Epoch 185/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2065
Epoch 186/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2053
Epoch 187/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2038
Epoch 188/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2026
Epoch 189/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.2017
Epoch 190/200
6/6 - 0s - 6ms/step - accuracy: 0.9417 - loss: 0.2002
Epoch 191/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.1990
Epoch 192/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.1977
Epoch 193/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.1965
Epoch 194/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.1955
Epoch 195/200
6/6 - 0s - 5ms/step - accuracy: 0.9417 - loss: 0.1941
Epoch 196/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.1929
Epoch 197/200
6/6 - 0s - 5ms/step - accuracy: 0.9583 - loss: 0.1920
Epoch 198/200
6/6 - 0s - 5ms/step - accuracy: 0.9583 - loss: 0.1906
Epoch 199/200
6/6 - 0s - 5ms/step - accuracy: 0.9583 - loss: 0.1895
Epoch 200/200
6/6 - 0s - 5ms/step - accuracy: 0.9583 - loss: 0.1883
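The "6/6" prefix on each log line above is the number of mini-batches processed per epoch. As a quick check, assuming the 120 training rows from the 80/20 split and the batch size and epoch count passed to fit():

```r
n_train_samples <- 120  # ceiling(0.8 * 150), as computed when splitting the data
batch_size      <- 20
epochs          <- 200

# Each epoch walks through the training data once, one mini-batch at a time
steps_per_epoch <- ceiling(n_train_samples / batch_size)

# Each mini-batch triggers one weight update
total_updates <- steps_per_epoch * epochs

cat("steps per epoch:", steps_per_epoch, "\n")
cat("total weight updates:", total_updates, "\n")
```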
plot(history) +
ggtitle("Training a neural network based classifier on the iris data set") +
theme_bw()
Evaluate Network Performance
The final performance can be obtained like so:
perf <- model |> evaluate(x_test, y_test)
1/1 - 0s - 297ms/step - accuracy: 0.9565 - loss: 0.1386
print(perf)
$accuracy
[1] 0.9565217
$loss
[1] 0.1386385
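The reported loss is the categorical cross-entropy chosen at compile time, and accuracy is the share of rows whose highest-probability class matches the true class. Below is a minimal base-R sketch of both quantities on made-up probabilities. The helper name categorical_crossentropy and the clipping constant eps are assumptions of this sketch and are not meant to mirror the Keras internals exactly.

```r
# Mean categorical cross-entropy: average over rows of -sum(y_true * log(y_prob))
categorical_crossentropy <- function(y_true, y_prob, eps = 1e-7) {
  y_prob <- pmin(pmax(y_prob, eps), 1 - eps)  # clip to avoid log(0)
  -mean(rowSums(y_true * log(y_prob)))
}

# Toy example: 3 observations, 3 classes, one-hot encoded truth
y_true <- rbind(c(1, 0, 0),
                c(0, 1, 0),
                c(0, 0, 1))

# Made-up predicted class probabilities (each row sums to 1)
y_prob <- rbind(c(0.90, 0.05, 0.05),
                c(0.10, 0.80, 0.10),
                c(0.20, 0.30, 0.50))

loss     <- categorical_crossentropy(y_true, y_prob)
accuracy <- mean(max.col(y_prob) == max.col(y_true))

cat("loss:", round(loss, 4), " accuracy:", accuracy, "\n")
```

Note that the loss keeps decreasing as the predicted probability of the true class approaches 1, even after the accuracy has stopped changing.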
classes <- iris |>
as_tibble() |>
pull(Species) |>
unique()
y_pred <- model |>
predict(x_test) |>
op_argmax(axis = -1) |>
as.numeric() - 1
1/1 - 0s - 139ms/step
y_true <- nn_dat |>
slice(test_indices) |>
pull(class_label)
tibble(
y_true = classes[y_true + 1],
y_pred = classes[y_pred + 1],
Correct = factor(ifelse(y_true == y_pred, "Yes", "No"))
) |>
ggplot(aes(x = y_true, y = y_pred, colour = Correct)) +
geom_jitter() +
theme_bw() +
ggtitle(label = "Classification Performance of Artificial Neural Network",
subtitle = str_c("Accuracy = ",round(perf$accuracy,3)*100,"%")) +
xlab(label = "True iris class") +
ylab(label = "Predicted iris class")
library(gmodels)
CrossTable(y_pred, y_true,
prop.chisq = FALSE, prop.t = FALSE, prop.r = FALSE,
dnn = c('predicted', 'actual'))
Cell Contents
|-------------------------|
| N |
| N / Col Total |
|-------------------------|
Total Observations in Table: 23
| actual
predicted | 0 | 1 | 2 | Row Total |
-------------|-----------|-----------|-----------|-----------|
0 | 12 | 0 | 0 | 12 |
| 1.000 | 0.000 | 0.000 | |
-------------|-----------|-----------|-----------|-----------|
1 | 0 | 6 | 1 | 7 |
| 0.000 | 1.000 | 0.200 | |
-------------|-----------|-----------|-----------|-----------|
2 | 0 | 0 | 4 | 4 |
| 0.000 | 0.000 | 0.800 | |
-------------|-----------|-----------|-----------|-----------|
Column Total | 12 | 6 | 5 | 23 |
| 0.522 | 0.261 | 0.217 | |
-------------|-----------|-----------|-----------|-----------|
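The accuracy reported by evaluate() can be recovered from the cross table by dividing the diagonal (correct predictions) by the total. The sketch below copies the counts from the table above into a matrix and also derives per-class recall and precision. The labels 0, 1 and 2 correspond to setosa, versicolor and virginica.

```r
# Counts copied from the cross table above (rows = predicted, columns = actual)
conf_mat <- matrix(
  c(12, 0, 0,
     0, 6, 1,
     0, 0, 4),
  nrow = 3, byrow = TRUE,
  dimnames = list(predicted = c("0", "1", "2"), actual = c("0", "1", "2"))
)

accuracy  <- sum(diag(conf_mat)) / sum(conf_mat)  # correct / total
recall    <- diag(conf_mat) / colSums(conf_mat)   # per actual class
precision <- diag(conf_mat) / rowSums(conf_mat)   # per predicted class

print(round(accuracy, 4))
print(round(recall, 3))
print(round(precision, 3))
```

The single error is a virginica flower (class 2) predicted as versicolor (class 1), which is the usual point of confusion in this data set.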
Conclusion
I hope this illustrated just how easy it is to get started building artificial neural networks using Keras and TensorFlow in R. With relative ease, we created a 3-class predictor with an accuracy of about 96% on the held-out test data. This was a basic minimal example. The network can be expanded to create Deep Learning networks, and the entire TensorFlow API is also available.
Enjoy and Happy Learning!
Leon
Thanks again Leon, this was awesome!!!