Building a simple neural network using Keras and TensorFlow - Updated
Update: The original code has been updated to split the data with the initial_split() function from tidymodels (provided by the rsample package) instead of the original index-based method. That method relied on setdiff(), which may now conflict between base R and the tidyverse.
Thank you
A big thank you to Leon Jessen for posting his code on GitHub.
Building a simple neural network using Keras and TensorFlow
I have forked his project on GitHub and put his code into an R Notebook so that we can run it in class.
Motivation
The following is a minimal example for building your first simple artificial neural network using Keras and TensorFlow for R.
Getting started - Install Keras and TensorFlow for R
You can install the Keras for R package (keras3) from CRAN as follows:
# install.packages("keras3")
TensorFlow is the default backend engine. TensorFlow and Keras can be installed as follows:
# library(keras3)
# install_keras(backend = "tensorflow")
Naturally, we will also need the tidyverse:
# install.packages("tidyverse")
Once installed, we simply load the libraries:
library("keras3")
library("tidyverse")
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr 1.1.4 ✔ readr 2.1.5
✔ forcats 1.0.1 ✔ stringr 1.5.2
✔ ggplot2 4.0.2 ✔ tibble 3.3.0
✔ lubridate 1.9.4 ✔ tidyr 1.3.1
✔ purrr 1.2.1
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag() masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
Artificial Neural Network Using the Iris Data Set
Right, let’s get to it!
Data
The famous (Fisher’s or Anderson’s) iris data set contains a total of 150 observations of 4 input features (Sepal.Length, Sepal.Width, Petal.Length and Petal.Width) and 3 output classes (setosa, versicolor and virginica), with 50 observations in each class. The distributions of the feature values look like so:
iris |>
as_tibble() |>
pivot_longer(
cols = -Species,
names_to = "feature",
values_to = "value"
) |>
ggplot(aes(x = feature, y = value, fill = Species)) +
geom_violin(alpha = 0.5, scale = "width") +
theme_bw()
Our aim is to connect the 4 input features to the correct output class using an artificial neural network. For this task, we have chosen a simple architecture with one input layer of 4 neurons (one for each feature), one hidden layer of 4 neurons and one output layer of 3 neurons (one for each class), all fully connected.
Our artificial neural network will have a total of 35 parameters. Each of the 4 input neurons is connected to each of the 4 hidden neurons (16 weights), plus 4 from the bias neuron feeding the hidden layer. Each of the 4 hidden neurons is connected to each of the 3 output neurons (12 weights), plus 3 from the bias neuron feeding the output layer. I.e. \(4 \times 4 + 4 + 4 \times 3 + 3 = 35\).
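The same count can be checked with a small helper. A dense layer has one weight per input-output pair plus one bias per output unit. `dense_params()` is a hypothetical name introduced only for this sketch, and the per-layer results should agree with the 20 and 15 parameters that summary() reports for the compiled model.

```r
# Trainable parameters of a fully connected (dense) layer:
# n_in * n_out weights plus n_out biases.
dense_params <- function(n_in, n_out) {
  n_in * n_out + n_out
}

# Architecture described above: 4 inputs -> 4 hidden -> 3 outputs
hidden_params <- dense_params(4, 4)  # 4 * 4 + 4 = 20
output_params <- dense_params(4, 3)  # 4 * 3 + 3 = 15
total_params  <- hidden_params + output_params

cat("hidden:", hidden_params, "output:", output_params, "total:", total_params, "\n")
```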
Prepare data
We start by lightly wrangling the iris data set. We rename and scale the features, and we convert the factor class labels (Species) to the numeric codes 0, 1 and 2:
set.seed(265509)
nn_dat <- iris |>
as_tibble() |>
mutate(sepal_length = scale(Sepal.Length),
sepal_width = scale(Sepal.Width),
petal_length = scale(Petal.Length),
petal_width = scale(Petal.Width),
class_label = as.numeric(Species) - 1) |>
select(sepal_length, sepal_width, petal_length, petal_width, class_label)
nn_dat |> head(3)
# A tibble: 3 × 5
sepal_length[,1] sepal_width[,1] petal_length[,1] petal_width[,1] class_label
<dbl> <dbl> <dbl> <dbl> <dbl>
1 -0.898 1.02 -1.34 -1.31 0
2 -1.14 -0.132 -1.34 -1.31 0
3 -1.38 0.327 -1.39 -1.31 0
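The base-R sketch below takes a closer look at this wrangling step on a single column, using only the iris data that ships with R. It checks that scale() computes the usual z-score. It also shows that scale() returns a one-column matrix, which is why the printed tibble labels each feature column with [,1]. Finally, it confirms how the Species factor maps to the 0/1/2 codes.

```r
# scale() standardises to mean 0 and standard deviation 1,
# i.e. z = (x - mean(x)) / sd(x), and returns a one-column matrix.
z <- scale(iris$Sepal.Length)
dim(z)  # 150 1: a matrix, hence the "[,1]" suffix in the tibble header

# The same standardisation written out by hand
z_manual <- (iris$Sepal.Length - mean(iris$Sepal.Length)) / sd(iris$Sepal.Length)
all.equal(as.numeric(z), z_manual)  # TRUE

# Species is a factor with levels setosa, versicolor, virginica;
# as.numeric() gives the integer codes 1, 2, 3, so subtracting 1 yields 0, 1, 2
levels(iris$Species)
table(as.numeric(iris$Species) - 1)  # 50 observations per class
```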
Then, we split the iris data into a training and a test data set with initial_split() from the rsample package (loaded with tidymodels), setting aside 20% of the data for testing:
library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.4.1 ──
✔ broom 1.0.10 ✔ rsample 1.3.1
✔ dials 1.4.2 ✔ tailor 0.1.0
✔ infer 1.0.9 ✔ tune 2.0.1
✔ modeldata 1.5.1 ✔ workflows 1.3.0
✔ parsnip 1.3.3 ✔ workflowsets 1.1.1
✔ recipes 1.3.1 ✔ yardstick 1.3.2
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ scales::discard() masks purrr::discard()
✖ dplyr::filter() masks stats::filter()
✖ recipes::fixed() masks stringr::fixed()
✖ infer::generate() masks keras3::generate()
✖ yardstick::get_weights() masks keras3::get_weights()
✖ dplyr::lag() masks stats::lag()
✖ yardstick::spec() masks readr::spec()
✖ recipes::step() masks stats::step()
set.seed(364)
n <- nrow(nn_dat)
n
[1] 150
iris_parts <- nn_dat |>
initial_split(prop = 0.8)
train <- iris_parts |>
training()
test <- iris_parts |>
testing()
list(train, test) |>
map_int(nrow)
[1] 120 30
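The idea behind this 80/20 split can be sketched in base R: draw a random 80% of the row numbers for training and keep the rest for testing. This is only an illustration of the concept and not rsample's implementation, which also supports options such as stratification. The rows drawn here will not match those chosen by initial_split(). The sketch avoids setdiff() because of the conflict mentioned in the update note.

```r
# Base-R sketch of a simple random 80/20 train/test split (illustration only).
set.seed(364)
n <- nrow(iris)
n_train <- floor(0.8 * n)                 # 120 of 150 rows for training
train_idx <- sample(seq_len(n), n_train)  # random rows, without replacement
test_idx  <- seq_len(n)[-train_idx]       # every remaining row goes to the test set

c(length(train_idx), length(test_idx))    # 120 30
any(train_idx %in% test_idx)              # FALSE: no row is in both sets
```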
n_total_samples <- nrow(nn_dat)
n_train_samples <- nrow(train)
n_test_samples <- nrow(test)
Based on the split, we can now create the training and test data, i.e. feature matrices (x) and one-hot encoded class labels (y):
x_train <- train |>
select(-class_label) |>
as.matrix()
y_train <- train |>
select(class_label) |>
as.matrix() |>
to_categorical()
x_test <- test |>
select(-class_label) |>
as.matrix()
y_test <- test |>
select(class_label) |>
as.matrix() |>
to_categorical()
dim(y_train)
[1] 120 3
dim(y_test)
[1] 30 3
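to_categorical() turns each 0-based class label into a one-hot row, which is why y_train and y_test have 3 columns. The base-R sketch below does the same thing. The helper one_hot() is hypothetical and not part of keras3. Label k becomes row k + 1 of the identity matrix.

```r
# One-hot encoding of 0-based integer class labels in base R.
# Label k (0, 1 or 2) becomes a row with a single 1 in column k + 1.
one_hot <- function(labels, num_classes = max(labels) + 1) {
  diag(num_classes)[labels + 1, , drop = FALSE]
}

one_hot(c(0, 2, 1, 0))
#      [,1] [,2] [,3]
# [1,]    1    0    0
# [2,]    0    0    1
# [3,]    0    1    0
# [4,]    1    0    0
```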
Set Architecture
With the data in place, we now set the architecture of our artificial neural network:
model <- keras_model_sequential()
model |>
layer_dense(units = 4, activation = 'relu', input_shape = 4) |>
layer_dense(units = 3, activation = 'softmax')
model |> summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense) │ (None, 4) │ 20 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_1 (Dense) │ (None, 3) │ 15 │
└───────────────────────────────────┴──────────────────────────┴───────────────┘
Total params: 35 (140.00 B)
Trainable params: 35 (140.00 B)
Non-trainable params: 0 (0.00 B)
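The two layer_dense() calls correspond to a simple forward pass. The first computes a ReLU of (inputs %*% weights + bias) for the 4 hidden units, and the second applies a softmax to the 3 output scores so that each row becomes a set of class probabilities. The base-R sketch below uses random placeholder weights, not the trained Keras weights. The helper names relu() and softmax_rows() are introduced only for illustration.

```r
# Forward pass of the 4-4-3 network in base R: dense(relu) -> dense(softmax).
# Weights are random placeholders, not the trained Keras weights.
set.seed(1)
relu <- function(x) pmax(x, 0)
softmax_rows <- function(z) {
  z <- z - apply(z, 1, max)  # subtract each row's max for numerical stability
  e <- exp(z)
  e / rowSums(e)             # each row now sums to 1
}

x  <- matrix(rnorm(5 * 4), nrow = 5)               # 5 example rows, 4 features
W1 <- matrix(rnorm(4 * 4), nrow = 4); b1 <- rnorm(4)  # hidden: 16 + 4 = 20 parameters
W2 <- matrix(rnorm(4 * 3), nrow = 4); b2 <- rnorm(3)  # output: 12 + 3 = 15 parameters

hidden <- relu(sweep(x %*% W1, 2, b1, "+"))              # 5 x 4, bias added per column
probs  <- softmax_rows(sweep(hidden %*% W2, 2, b2, "+")) # 5 x 3 class probabilities

dim(probs)      # 5 3
rowSums(probs)  # all 1
```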
Next, the architecture set in the model needs to be compiled:
model |> compile(
loss = 'categorical_crossentropy',
optimizer = optimizer_rmsprop(),
metrics = c('accuracy')
)
Train the Artificial Neural Network
Lastly, we fit the model and save the training progress in the history object:
history <- model |> fit(
x = x_train, y = y_train,
epochs = 200,
batch_size = 20,
validation_split = 0
)
Epoch 1/200
6/6 - 1s - 124ms/step - accuracy: 0.2667 - loss: 1.5846
Epoch 2/200
6/6 - 0s - 6ms/step - accuracy: 0.2917 - loss: 1.5461
Epoch 3/200
6/6 - 0s - 6ms/step - accuracy: 0.3250 - loss: 1.5185
Epoch 4/200
6/6 - 0s - 5ms/step - accuracy: 0.3583 - loss: 1.4938
Epoch 5/200
6/6 - 0s - 5ms/step - accuracy: 0.3750 - loss: 1.4701
Epoch 6/200
6/6 - 0s - 5ms/step - accuracy: 0.3750 - loss: 1.4468
Epoch 7/200
6/6 - 0s - 5ms/step - accuracy: 0.3917 - loss: 1.4239
Epoch 8/200
6/6 - 0s - 5ms/step - accuracy: 0.4250 - loss: 1.4014
Epoch 9/200
6/6 - 0s - 5ms/step - accuracy: 0.4417 - loss: 1.3789
Epoch 10/200
6/6 - 0s - 5ms/step - accuracy: 0.4500 - loss: 1.3565
Epoch 11/200
6/6 - 0s - 5ms/step - accuracy: 0.4583 - loss: 1.3345
Epoch 12/200
6/6 - 0s - 5ms/step - accuracy: 0.4667 - loss: 1.3131
Epoch 13/200
6/6 - 0s - 5ms/step - accuracy: 0.4583 - loss: 1.2914
Epoch 14/200
6/6 - 0s - 5ms/step - accuracy: 0.4833 - loss: 1.2699
Epoch 15/200
6/6 - 0s - 5ms/step - accuracy: 0.4917 - loss: 1.2491
Epoch 16/200
6/6 - 0s - 5ms/step - accuracy: 0.4917 - loss: 1.2284
Epoch 17/200
6/6 - 0s - 5ms/step - accuracy: 0.5000 - loss: 1.2080
Epoch 18/200
6/6 - 0s - 5ms/step - accuracy: 0.5167 - loss: 1.1879
Epoch 19/200
6/6 - 0s - 5ms/step - accuracy: 0.5333 - loss: 1.1680
Epoch 20/200
6/6 - 0s - 5ms/step - accuracy: 0.5333 - loss: 1.1481
Epoch 21/200
6/6 - 0s - 5ms/step - accuracy: 0.5417 - loss: 1.1286
Epoch 22/200
6/6 - 0s - 5ms/step - accuracy: 0.5500 - loss: 1.1095
Epoch 23/200
6/6 - 0s - 5ms/step - accuracy: 0.5417 - loss: 1.0900
Epoch 24/200
6/6 - 0s - 5ms/step - accuracy: 0.5500 - loss: 1.0711
Epoch 25/200
6/6 - 0s - 5ms/step - accuracy: 0.5583 - loss: 1.0525
Epoch 26/200
6/6 - 0s - 5ms/step - accuracy: 0.5667 - loss: 1.0342
Epoch 27/200
6/6 - 0s - 5ms/step - accuracy: 0.5583 - loss: 1.0162
Epoch 28/200
6/6 - 0s - 5ms/step - accuracy: 0.5750 - loss: 0.9988
Epoch 29/200
6/6 - 0s - 5ms/step - accuracy: 0.5750 - loss: 0.9815
Epoch 30/200
6/6 - 0s - 5ms/step - accuracy: 0.6000 - loss: 0.9647
Epoch 31/200
6/6 - 0s - 5ms/step - accuracy: 0.6167 - loss: 0.9482
Epoch 32/200
6/6 - 0s - 5ms/step - accuracy: 0.6250 - loss: 0.9322
Epoch 33/200
6/6 - 0s - 5ms/step - accuracy: 0.6167 - loss: 0.9164
Epoch 34/200
6/6 - 0s - 5ms/step - accuracy: 0.6250 - loss: 0.9011
Epoch 35/200
6/6 - 0s - 5ms/step - accuracy: 0.6333 - loss: 0.8862
Epoch 36/200
6/6 - 0s - 5ms/step - accuracy: 0.6417 - loss: 0.8712
Epoch 37/200
6/6 - 0s - 5ms/step - accuracy: 0.6417 - loss: 0.8569
Epoch 38/200
6/6 - 0s - 5ms/step - accuracy: 0.6500 - loss: 0.8429
Epoch 39/200
6/6 - 0s - 6ms/step - accuracy: 0.6583 - loss: 0.8292
Epoch 40/200
6/6 - 0s - 5ms/step - accuracy: 0.6583 - loss: 0.8160
Epoch 41/200
6/6 - 0s - 6ms/step - accuracy: 0.6583 - loss: 0.8029
Epoch 42/200
6/6 - 0s - 5ms/step - accuracy: 0.6583 - loss: 0.7905
Epoch 43/200
6/6 - 0s - 5ms/step - accuracy: 0.6750 - loss: 0.7782
Epoch 44/200
6/6 - 0s - 5ms/step - accuracy: 0.6750 - loss: 0.7662
Epoch 45/200
6/6 - 0s - 5ms/step - accuracy: 0.6833 - loss: 0.7546
Epoch 46/200
6/6 - 0s - 5ms/step - accuracy: 0.7083 - loss: 0.7430
Epoch 47/200
6/6 - 0s - 5ms/step - accuracy: 0.7250 - loss: 0.7320
Epoch 48/200
6/6 - 0s - 5ms/step - accuracy: 0.7250 - loss: 0.7216
Epoch 49/200
6/6 - 0s - 5ms/step - accuracy: 0.7417 - loss: 0.7113
Epoch 50/200
6/6 - 0s - 7ms/step - accuracy: 0.7583 - loss: 0.7013
Epoch 51/200
6/6 - 0s - 5ms/step - accuracy: 0.7667 - loss: 0.6915
Epoch 52/200
6/6 - 0s - 5ms/step - accuracy: 0.7750 - loss: 0.6821
Epoch 53/200
6/6 - 0s - 5ms/step - accuracy: 0.7833 - loss: 0.6730
Epoch 54/200
6/6 - 0s - 5ms/step - accuracy: 0.8083 - loss: 0.6641
Epoch 55/200
6/6 - 0s - 5ms/step - accuracy: 0.8083 - loss: 0.6556
Epoch 56/200
6/6 - 0s - 5ms/step - accuracy: 0.8167 - loss: 0.6474
Epoch 57/200
6/6 - 0s - 5ms/step - accuracy: 0.8167 - loss: 0.6396
Epoch 58/200
6/6 - 0s - 5ms/step - accuracy: 0.8250 - loss: 0.6319
Epoch 59/200
6/6 - 0s - 5ms/step - accuracy: 0.8583 - loss: 0.6247
Epoch 60/200
6/6 - 0s - 5ms/step - accuracy: 0.8667 - loss: 0.6176
Epoch 61/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.6108
Epoch 62/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.6043
Epoch 63/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.5979
Epoch 64/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.5918
Epoch 65/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.5856
Epoch 66/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.5801
Epoch 67/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.5742
Epoch 68/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.5689
Epoch 69/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.5637
Epoch 70/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.5587
Epoch 71/200
6/6 - 0s - 5ms/step - accuracy: 0.9000 - loss: 0.5541
Epoch 72/200
6/6 - 0s - 6ms/step - accuracy: 0.9000 - loss: 0.5492
Epoch 73/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.5447
Epoch 74/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.5408
Epoch 75/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.5365
Epoch 76/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.5324
Epoch 77/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.5286
Epoch 78/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.5247
Epoch 79/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.5209
Epoch 80/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.5171
Epoch 81/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.5137
Epoch 82/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.5099
Epoch 83/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.5065
Epoch 84/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.5031
Epoch 85/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.4995
Epoch 86/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.4963
Epoch 87/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.4930
Epoch 88/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.4897
Epoch 89/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.4865
Epoch 90/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.4834
Epoch 91/200
6/6 - 0s - 6ms/step - accuracy: 0.9167 - loss: 0.4802
Epoch 92/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4771
Epoch 93/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4741
Epoch 94/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4708
Epoch 95/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.4675
Epoch 96/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4644
Epoch 97/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4612
Epoch 98/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.4579
Epoch 99/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4548
Epoch 100/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.4516
Epoch 101/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4486
Epoch 102/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4456
Epoch 103/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.4425
Epoch 104/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.4394
Epoch 105/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.4364
Epoch 106/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4332
Epoch 107/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4300
Epoch 108/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.4268
Epoch 109/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4236
Epoch 110/200
6/6 - 0s - 5ms/step - accuracy: 0.9083 - loss: 0.4207
Epoch 111/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.4174
Epoch 112/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.4144
Epoch 113/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.4114
Epoch 114/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.4084
Epoch 115/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.4054
Epoch 116/200
6/6 - 0s - 6ms/step - accuracy: 0.9250 - loss: 0.4022
Epoch 117/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3989
Epoch 118/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3957
Epoch 119/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3925
Epoch 120/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3891
Epoch 121/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3862
Epoch 122/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3828
Epoch 123/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3796
Epoch 124/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3763
Epoch 125/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3730
Epoch 126/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3700
Epoch 127/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3665
Epoch 128/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3635
Epoch 129/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3605
Epoch 130/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3573
Epoch 131/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3542
Epoch 132/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3513
Epoch 133/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.3482
Epoch 134/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3450
Epoch 135/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3417
Epoch 136/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3384
Epoch 137/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3353
Epoch 138/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3322
Epoch 139/200
6/6 - 0s - 6ms/step - accuracy: 0.9167 - loss: 0.3290
Epoch 140/200
6/6 - 0s - 6ms/step - accuracy: 0.9167 - loss: 0.3257
Epoch 141/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3224
Epoch 142/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3194
Epoch 143/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3163
Epoch 144/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3135
Epoch 145/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3103
Epoch 146/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3078
Epoch 147/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3049
Epoch 148/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.3023
Epoch 149/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.2997
Epoch 150/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.2970
Epoch 151/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.2942
Epoch 152/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.2915
Epoch 153/200
6/6 - 0s - 6ms/step - accuracy: 0.9167 - loss: 0.2886
Epoch 154/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.2859
Epoch 155/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.2831
Epoch 156/200
6/6 - 0s - 5ms/step - accuracy: 0.9167 - loss: 0.2803
Epoch 157/200
6/6 - 0s - 5ms/step - accuracy: 0.9250 - loss: 0.2777
Epoch 158/200
6/6 - 0s - 6ms/step - accuracy: 0.9167 - loss: 0.2748
Epoch 159/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2722
Epoch 160/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2698
Epoch 161/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2670
Epoch 162/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2645
Epoch 163/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2622
Epoch 164/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2595
Epoch 165/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2571
Epoch 166/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2546
Epoch 167/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2520
Epoch 168/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2495
Epoch 169/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2475
Epoch 170/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2451
Epoch 171/200
6/6 - 0s - 5ms/step - accuracy: 0.9333 - loss: 0.2428
Epoch 172/200
6/6 - 0s - 6ms/step - accuracy: 0.9500 - loss: 0.2409
Epoch 173/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2389
Epoch 174/200
6/6 - 0s - 6ms/step - accuracy: 0.9500 - loss: 0.2366
Epoch 175/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2344
Epoch 176/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2322
Epoch 177/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2301
Epoch 178/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2281
Epoch 179/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2256
Epoch 180/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2241
Epoch 181/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2213
Epoch 182/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2194
Epoch 183/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2172
Epoch 184/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2149
Epoch 185/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2129
Epoch 186/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2108
Epoch 187/200
6/6 - 0s - 6ms/step - accuracy: 0.9500 - loss: 0.2086
Epoch 188/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2066
Epoch 189/200
6/6 - 0s - 5ms/step - accuracy: 0.9583 - loss: 0.2045
Epoch 190/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.2025
Epoch 191/200
6/6 - 0s - 5ms/step - accuracy: 0.9583 - loss: 0.2002
Epoch 192/200
6/6 - 0s - 5ms/step - accuracy: 0.9500 - loss: 0.1980
Epoch 193/200
6/6 - 0s - 5ms/step - accuracy: 0.9583 - loss: 0.1964
Epoch 194/200
6/6 - 0s - 5ms/step - accuracy: 0.9583 - loss: 0.1945
Epoch 195/200
6/6 - 0s - 5ms/step - accuracy: 0.9583 - loss: 0.1923
Epoch 196/200
6/6 - 0s - 5ms/step - accuracy: 0.9667 - loss: 0.1906
Epoch 197/200
6/6 - 0s - 5ms/step - accuracy: 0.9667 - loss: 0.1888
Epoch 198/200
6/6 - 0s - 5ms/step - accuracy: 0.9667 - loss: 0.1869
Epoch 199/200
6/6 - 0s - 5ms/step - accuracy: 0.9667 - loss: 0.1850
Epoch 200/200
6/6 - 0s - 5ms/step - accuracy: 0.9667 - loss: 0.1835
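With 120 training rows and batch_size = 20, each epoch consists of 120 / 20 = 6 gradient updates, which is the 6/6 shown on every line of the log. Each update is made by the RMSprop optimizer selected in compile(). The following base-R sketch shows the textbook RMSprop rule on a one-parameter toy problem that minimises (w - 3)^2. The learning rate, rho and epsilon are illustrative values, and Keras's own implementation may differ in details such as where epsilon enters the denominator.

```r
# Textbook RMSprop on a toy loss f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
# learning_rate, rho and epsilon are illustrative values chosen for this sketch.
rmsprop_minimise <- function(grad, w, learning_rate = 0.01, rho = 0.9,
                             epsilon = 1e-7, steps = 2000) {
  v <- 0  # running average of squared gradients
  for (i in seq_len(steps)) {
    g <- grad(w)
    v <- rho * v + (1 - rho) * g^2                    # update second-moment estimate
    w <- w - learning_rate * g / (sqrt(v) + epsilon)  # step scaled by recent gradient size
  }
  w
}

w_final <- rmsprop_minimise(function(w) 2 * (w - 3), w = 0)
w_final  # close to 3
```

Dividing by the running root-mean-square of the gradient makes the step size roughly independent of the gradient's scale. That is the main difference from plain gradient descent.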
plot(history) +
ggtitle("Training a neural network based classifier on the iris data set") +
theme_bw()
Evaluate Network Performance
The final performance can be obtained like so:
perf <- model |> evaluate(x_test, y_test)
1/1 - 0s - 290ms/step - accuracy: 0.9667 - loss: 0.1599
print(perf)
$accuracy
[1] 0.9666666
$loss
[1] 0.1598504
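The loss reported by evaluate() is the categorical cross-entropy chosen in compile(), averaged over the test observations. For each observation, it is minus the log of the probability the network assigns to the true class. Lower values therefore mean more probability mass on the correct class. A base-R sketch follows, in which the two toy observations are made up for illustration. The sketch does not guard against log(0).

```r
# Mean categorical cross-entropy between one-hot targets and predicted
# probabilities: -mean over rows of sum_k y_k * log(p_k).
categorical_crossentropy <- function(y_onehot, p_hat) {
  -mean(rowSums(y_onehot * log(p_hat)))
}

# Two toy observations (values made up for illustration)
y_onehot <- rbind(c(1, 0, 0), c(0, 1, 0))
p_hat    <- rbind(c(0.7, 0.2, 0.1), c(0.1, 0.8, 0.1))
categorical_crossentropy(y_onehot, p_hat)  # (-log(0.7) - log(0.8)) / 2, about 0.29
```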
classes <- iris |>
as_tibble() |>
pull(Species) |>
unique()
y_pred <- model |>
predict(x_test) |>
op_argmax(axis = -1) |>
as.numeric() - 1
1/1 - 0s - 120ms/step
y_true <- test |>
select(class_label) |>
unlist() |>
as.numeric()
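The code above subtracts 1 from the op_argmax() result so that the predictions share the 0-based coding of class_label. Adding 1 back then indexes the class names. The same mapping can be written in base R with which.max(), which returns 1-based positions by definition. The toy probability matrix below is made up for illustration.

```r
# From a matrix of class probabilities to 0-based labels and class names.
p_hat <- rbind(c(0.90, 0.07, 0.03),  # toy probabilities, made up for illustration
               c(0.10, 0.30, 0.60),
               c(0.20, 0.70, 0.10))

# which.max() gives the 1-based column of the largest probability in each row
label_0based <- apply(p_hat, 1, which.max) - 1
label_0based                            # 0 2 1
levels(iris$Species)[label_0based + 1]  # "setosa" "virginica" "versicolor"
```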
tibble(
y_true = classes[y_true + 1],
y_pred = classes[y_pred + 1],
Correct = factor(ifelse(y_true == y_pred, "Yes", "No"))
) |>
ggplot(aes(x = y_true, y = y_pred, colour = Correct)) +
geom_jitter() +
theme_bw() +
ggtitle(label = "Classification Performance of Artificial Neural Network",
subtitle = str_c("Accuracy = ",round(perf$accuracy,3)*100,"%")) +
xlab(label = "True iris class") +
ylab(label = "Predicted iris class")
library(gmodels)
CrossTable(y_pred, y_true,
prop.chisq = FALSE, prop.t = FALSE, prop.r = FALSE,
dnn = c('predicted', 'actual'))
Cell Contents
|-------------------------|
| N |
| N / Col Total |
|-------------------------|
Total Observations in Table: 30
| actual
predicted | 0 | 1 | 2 | Row Total |
-------------|-----------|-----------|-----------|-----------|
0 | 10 | 0 | 0 | 10 |
| 1.000 | 0.000 | 0.000 | |
-------------|-----------|-----------|-----------|-----------|
1 | 0 | 9 | 1 | 10 |
| 0.000 | 1.000 | 0.091 | |
-------------|-----------|-----------|-----------|-----------|
2 | 0 | 0 | 10 | 10 |
| 0.000 | 0.000 | 0.909 | |
-------------|-----------|-----------|-----------|-----------|
Column Total | 10 | 9 | 11 | 30 |
| 0.333 | 0.300 | 0.367 | |
-------------|-----------|-----------|-----------|-----------|
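Base R's table() and prop.table() give the same counts and column proportions without the gmodels dependency. The two vectors below are reconstructed from the CrossTable() output above rather than taken from the model. The diagonal holds the correct predictions, so 29 of the 30 test observations are classified correctly, matching the accuracy of 0.9667 reported by evaluate().

```r
# Rebuild the 30 test labels from the CrossTable() counts (0-based classes)
actual    <- c(rep(0, 10), rep(1, 9), rep(2, 11))
predicted <- c(rep(0, 10), rep(1, 9), 1, rep(2, 10))  # one virginica predicted as versicolor

cm <- table(predicted, actual)
cm
prop.table(cm, margin = 2)           # "N / Col Total": proportions within each actual class
accuracy <- sum(diag(cm)) / sum(cm)  # correct predictions sit on the diagonal
accuracy                             # 29 / 30, about 0.967
```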
Conclusion
I hope this illustrated just how easy it is to get started building an artificial neural network using Keras and TensorFlow in R. With relative ease, we created a 3-class predictor with an accuracy of about 96.7% on the held-out test set (29 of 30 test observations classified correctly). This was a basic minimal example. The network can be expanded into deep learning networks, and the entire TensorFlow API is also available.
Enjoy and Happy Learning!
Leon
Thanks again Leon, this was awesome!!!