Building a simple neural network using Keras and TensorFlow - Updated

Author

Prof. Eric A. Suess

Update: The original code has been updated to use the tidymodels initial_split() function rather than the original index-based approach, which relied on setdiff() and can now run into a conflict between base R and the tidyverse.

Thank you

A big thank you to Leon Jessen for posting his code on github.

Building a simple neural network using Keras and TensorFlow

I have forked his project on github and put his code into an R Notebook so we can run it in class.

Motivation

The following is a minimal example for building your first simple artificial neural network using Keras and TensorFlow for R.

TensorFlow for R by RStudio lives here.

Getting started - Install Keras and TensorFlow for R

You can install the Keras for R package from CRAN as follows:

# install.packages("keras3")

TensorFlow is the default backend engine. TensorFlow and Keras can be installed as follows:

# library(keras3)
# install_keras()

Naturally, we will also need the tidyverse:

# install.packages("tidyverse")

Once installed, we simply load the libraries:

library("keras3")
library("tidyverse")
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.4     ✔ readr     2.1.5
✔ forcats   1.0.1     ✔ stringr   1.5.2
✔ ggplot2   4.0.2     ✔ tibble    3.3.0
✔ lubridate 1.9.4     ✔ tidyr     1.3.1
✔ purrr     1.2.1     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors

Artificial Neural Network Using the Iris Data Set

Right, let’s get to it!

Data

The famous (Fisher’s or Anderson’s) iris data set contains a total of 150 observations of 4 input features (Sepal.Length, Sepal.Width, Petal.Length, and Petal.Width) and 3 output classes (setosa, versicolor, and virginica), with 50 observations in each class. The distributions of the feature values look like this:

iris |> 
  as_tibble() |> 
  pivot_longer(
    cols = -Species, 
    names_to = "feature", 
    values_to = "value"
  ) |>
  ggplot(aes(x = feature, y = value, fill = Species)) +
  geom_violin(alpha = 0.5, scale = "width") +
  theme_bw()

Our aim is to connect the 4 input features to the correct output class using an artificial neural network. For this task, we have chosen the following simple architecture with one input layer with 4 neurons (one for each feature), one hidden layer with 4 neurons and one output layer with 3 neurons (one for each class), all fully connected:

architecture_visualisation.png

Our artificial neural network will have a total of 35 parameters: 4 weights for each of the 4 input neurons connected to the hidden layer, plus an additional 4 for the hidden layer’s bias neurons, and 3 weights for each of the 4 hidden neurons connected to the output layer, plus an additional 3 for the output layer’s bias neurons. I.e. \(4 \times 4 + 4 + 4 \times 3 + 3 = 35\).
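The arithmetic can be checked directly in R:

```r
# input -> hidden: 4 x 4 weights + 4 biases
# hidden -> output: 4 x 3 weights + 3 biases
n_params <- 4 * 4 + 4 + 4 * 3 + 3
n_params
# 35, matching the model summary we will see later
```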

Prepare data

We start by lightly wrangling the iris data set: we rename and scale the features and convert the character labels to numeric:

set.seed(265509)
nn_dat <- iris |> 
  as_tibble() |>
  mutate(sepal_length = scale(Sepal.Length),
         sepal_width  = scale(Sepal.Width),
         petal_length = scale(Petal.Length),
         petal_width  = scale(Petal.Width),          
         class_label  = as.numeric(Species) - 1) |> 
    select(sepal_length, sepal_width, petal_length, petal_width, class_label)

nn_dat |> head(3)
# A tibble: 3 × 5
  sepal_length[,1] sepal_width[,1] petal_length[,1] petal_width[,1] class_label
             <dbl>           <dbl>            <dbl>           <dbl>       <dbl>
1           -0.898           1.02             -1.34           -1.31           0
2           -1.14           -0.132            -1.34           -1.31           0
3           -1.38            0.327            -1.39           -1.31           0

Then, we split the iris data into a training set and a test set using rsample’s initial_split(), setting aside 20% of the data for testing:

library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.4.1 ──
✔ broom        1.0.10     ✔ rsample      1.3.1 
✔ dials        1.4.2      ✔ tailor       0.1.0 
✔ infer        1.0.9      ✔ tune         2.0.1 
✔ modeldata    1.5.1      ✔ workflows    1.3.0 
✔ parsnip      1.3.3      ✔ workflowsets 1.1.1 
✔ recipes      1.3.1      ✔ yardstick    1.3.2 
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ scales::discard()        masks purrr::discard()
✖ dplyr::filter()          masks stats::filter()
✖ recipes::fixed()         masks stringr::fixed()
✖ infer::generate()        masks keras3::generate()
✖ yardstick::get_weights() masks keras3::get_weights()
✖ dplyr::lag()             masks stats::lag()
✖ yardstick::spec()        masks readr::spec()
✖ recipes::step()          masks stats::step()
set.seed(364)
n <- nrow(nn_dat)
n
[1] 150
iris_parts <- nn_dat |> 
  initial_split(prop = 0.8)

train <- iris_parts |> 
  training()

test <- iris_parts |> 
  testing()

list(train, test) |> 
  map_int(nrow)
[1] 120  30
n_total_samples <- nrow(nn_dat)

n_train_samples <- nrow(train)

n_test_samples <- nrow(test)

From the split, we can now create the training and test data:

x_train <- train |>
  select(-class_label) |>
  as.matrix()

y_train <- train |> 
  select(class_label) |> 
  as.matrix() |> 
  to_categorical()

x_test <- test |>  
  select(-class_label) |> 
  as.matrix()

y_test <- test |> 
  select(class_label) |> 
  as.matrix() |> 
  to_categorical() 

dim(y_train)
[1] 120   3
dim(y_test)
[1] 30  3
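As a quick aside, to_categorical() one-hot encodes the zero-based class labels into indicator columns, which is why y_train and y_test have 3 columns. A minimal sketch of what it does on a toy label vector (assuming keras3 is loaded):

```r
library(keras3)

# One-hot encode the zero-based labels 0, 1, 2
# into a 3 x 3 indicator matrix
to_categorical(c(0, 1, 2))
# Each row should contain a single 1 in the column
# corresponding to its class, e.g. row 1 is 1 0 0 for label 0
```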

Set Architecture

With the data in place, we now set the architecture of our artificial neural network:

model <- keras_model_sequential()
model |> 
  layer_dense(units = 4, activation = 'relu', input_shape = 4) |> 
  layer_dense(units = 3, activation = 'softmax')
model |> summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                      ┃ Output Shape             ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                     │ (None, 4)                │            20 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_1 (Dense)                   │ (None, 3)                │            15 │
└───────────────────────────────────┴──────────────────────────┴───────────────┘
 Total params: 35 (140.00 B)
 Trainable params: 35 (140.00 B)
 Non-trainable params: 0 (0.00 B)

Next, the architecture set in the model needs to be compiled:

model |> compile(
  loss      = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(),
  metrics   = c('accuracy')
)

Train the Artificial Neural Network

Lastly, we fit the model and save the training progress in the history object:

history <- model |> fit(
  x = x_train, y = y_train,
  epochs = 200,
  batch_size = 20,
  validation_split = 0
)
Epoch 1/200
6/6 - 1s - 140ms/step - accuracy: 0.0833 - loss: 2.0592
Epoch 2/200
6/6 - 0s - 8ms/step - accuracy: 0.0833 - loss: 1.9941
Epoch 3/200
6/6 - 0s - 7ms/step - accuracy: 0.0833 - loss: 1.9469
...
Epoch 198/200
6/6 - 0s - 7ms/step - accuracy: 0.8417 - loss: 0.4620
Epoch 199/200
6/6 - 0s - 9ms/step - accuracy: 0.8500 - loss: 0.4604
Epoch 200/200
6/6 - 0s - 6ms/step - accuracy: 0.8500 - loss: 0.4588
plot(history) +
  ggtitle("Training a neural network based classifier on the iris data set") +
  theme_bw()

Evaluate Network Performance

The final performance can be obtained like so:

perf <- model |> evaluate(x_test, y_test)
1/1 - 0s - 323ms/step - accuracy: 0.8667 - loss: 0.4595
print(perf)
$accuracy
[1] 0.8666667

$loss
[1] 0.4595394
classes <- iris |> 
  as_tibble() |> 
  pull(Species) |> 
  unique()

y_pred <- model |> 
  predict(x_test)  |> 
  op_argmax(axis = -1) |>
  as.numeric() - 1
1/1 - 0s - 118ms/step
y_true  <- test |>  
  select(class_label) |>  
  unlist() |>  
  as.numeric()

tibble(
  y_true  = classes[y_true + 1], 
  y_pred  = classes[y_pred + 1],
  Correct = factor(ifelse(y_true == y_pred, "Yes", "No"))
) |>
  ggplot(aes(x = y_true, y = y_pred, colour = Correct)) +
  geom_jitter() +
  theme_bw() +
  ggtitle(label = "Classification Performance of Artificial Neural Network",
          subtitle = str_c("Accuracy = ",round(perf$accuracy,3)*100,"%")) +
  xlab(label = "True iris class") +
  ylab(label = "Predicted iris class")

library(gmodels)

CrossTable(y_pred, y_true,
           prop.chisq = FALSE, prop.t = FALSE, prop.r = FALSE,
           dnn = c('predicted', 'actual'))

 
   Cell Contents
|-------------------------|
|                       N |
|           N / Col Total |
|-------------------------|

 
Total Observations in Table:  30 

 
             | actual 
   predicted |         0 |         1 |         2 | Row Total | 
-------------|-----------|-----------|-----------|-----------|
           0 |        10 |         0 |         0 |        10 | 
             |     1.000 |     0.000 |     0.000 |           | 
-------------|-----------|-----------|-----------|-----------|
           1 |         0 |         5 |         0 |         5 | 
             |     0.000 |     0.556 |     0.000 |           | 
-------------|-----------|-----------|-----------|-----------|
           2 |         0 |         4 |        11 |        15 | 
             |     0.000 |     0.444 |     1.000 |           | 
-------------|-----------|-----------|-----------|-----------|
Column Total |        10 |         9 |        11 |        30 | 
             |     0.333 |     0.300 |     0.367 |           | 
-------------|-----------|-----------|-----------|-----------|

 

Conclusion

I hope this illustrated just how easy it is to get started building an artificial neural network using Keras and TensorFlow in R. With relative ease, we created a 3-class predictor with a test accuracy of about 87%. This was a basic minimal example. The network can be expanded into deep learning networks, and the entire TensorFlow API is also available.

Enjoy and Happy Learning!

Leon

Thanks again Leon, this was awesome!!!