Writing code for data manipulation is an important part of any machine learning project or research paper. To me, data manipulation is a very general class of operations: anything that converts data from one format to another. Data manipulation is useful for (1) pre-processing, to get the right format for machine learning algorithms, and (2) post-processing, to get the right format for visualization using tables or figures.

Python pandas and R data.table for plotting keras model fit metrics

When training machine learning models such as neural networks, it is important to monitor various loss/accuracy metrics, and plot them as a function of the regularization hyper-parameter. For example, in my screencasts about neural networks using keras in R, we plot subtrain/validation loss/accuracy versus number of epochs, in order to see when the neural network starts to overfit. In this section I comment on how I coded this in R, and how I translated it into python.

The first step is to use the keras package in R to declare and fit a neural network model, which I did with the following code (note .... means a bunch of irrelevant code has been omitted):

model <- keras::keras_model_sequential() %>% ....
history <- model %>% keras::fit(....)

The fit function returns a history object whose metrics element is a named list of numeric vectors (names are metrics such as loss, val_loss, acc, val_acc for logistic loss and proportion accuracy with respect to train or validation sets), which we then convert to a data table with an additional epoch column via:

history.wide <- do.call(data.table::data.table, history$metrics)
history.wide[, epoch := 1:.N]

The resulting data table looks like

      val_loss   val_acc       loss       acc epoch
  1: 0.3261010 0.9010870 0.49783516 0.7798913     1
  2: 0.2535547 0.9114130 0.26578763 0.9228261     2
  3: 0.2392319 0.9163043 0.21651624 0.9326087     3
  4: 0.2274921 0.9206522 0.19466753 0.9304348     4
  5: 0.2212409 0.9255435 0.18036880 0.9336957     5
 ---                                               
 96: 0.3007694 0.9255435 0.05798265 0.9831522    96
 97: 0.3010736 0.9244565 0.05763476 0.9831522    97
 98: 0.3025759 0.9260870 0.05615792 0.9820652    98
 99: 0.3045697 0.9260870 0.05674289 0.9836957    99
100: 0.3036924 0.9250000 0.05580203 0.9820652   100

Doing the same computations in python looks like

import numpy as np
import tensorflow as tf
import pandas as pd
model = tf.keras.Model(....)
history = model.fit(....)
history_wide = pd.DataFrame(history.history)
history_wide["epoch"] = np.arange(len(history_wide.index))+1

The python model.fit method above returns a History object, whose history attribute is a dictionary with keys for metric names (val_loss etc) and values which are lists with one number per epoch. These are combined into a DataFrame with one row for each epoch.
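As a minimal sketch of this step that runs without training a model, the dictionary below stands in for history.history. The name metrics_dict is hypothetical, and the values are copied from the first three rows of the R table above.

```python
import pandas as pd

# Hypothetical stand-in for history.history: metric name -> one value per epoch.
# Values are the first three epochs from the R table above.
metrics_dict = {
    "val_loss": [0.3261010, 0.2535547, 0.2392319],
    "val_acc": [0.9010870, 0.9114130, 0.9163043],
    "loss": [0.49783516, 0.26578763, 0.21651624],
    "acc": [0.7798913, 0.9228261, 0.9326087],
}
history_wide = pd.DataFrame(metrics_dict)
# Epochs are numbered from 1, matching the R code.
history_wide["epoch"] = list(range(1, len(history_wide) + 1))
print(history_wide)
```

The result has one row per epoch and one column per metric, plus the epoch column, which is the same wide layout as the R data table.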

After having created a wide data table (with different metrics in different columns), the next step to visualizing these data with the grammar of graphics (ggplots) is to reshape the data into tall format, e.g.

     epoch prefix metric     value
  1:     1   val_   loss 0.3261010
  2:     2   val_   loss 0.2535547
  3:     3   val_   loss 0.2392319
  4:     4   val_   loss 0.2274921
  5:     5   val_   loss 0.2212409
 ---                              
396:    96           acc 0.9831522
397:    97           acc 0.9831522
398:    98           acc 0.9820652
399:    99           acc 0.9836957
400:   100           acc 0.9820652

In the tall data table above, the original four metric columns have been reshaped into a single value column. There is also a copy of the original epoch column, and two new columns (prefix and metric) which indicate the set and metric. To accomplish that conversion in R I used the following function from my nc package:

history.tall.sets <- nc::capture_melt_single(
  history.wide,
  prefix="val_|",
  metric="loss|acc")

The function call above performs the reshape operation on all of the columns from the first argument history.wide which match the regex provided in the other arguments. Those arguments are pasted together to form the final regex which is used for matching. For more info see my recently submitted article, and screencasts about the new functions in nc for data reshaping using regular expressions.
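To illustrate the idea of pasting named patterns together, here is a rough sketch using only the python re module. It is not the nc implementation: I assume each named argument becomes one capture group, I name the first group prefix to match the tall table above, and I use fullmatch as a simplification.

```python
import re

# Assumption: each named pattern is wrapped in a capture group and the
# groups are concatenated in order to form the final regex.
pattern = re.compile(r"(?P<prefix>val_|)(?P<metric>loss|acc)")

column_names = ["val_loss", "val_acc", "loss", "acc", "epoch"]
matches = {}
for name in column_names:
    m = pattern.fullmatch(name)
    if m is not None:
        # groupdict maps each group name to the text it captured.
        matches[name] = m.groupdict()
print(matches)
```

The four metric columns match (with an empty prefix for loss and acc), whereas epoch does not match, so it is left alone rather than reshaped.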

To accomplish something similar in the python code below we first use the melt function to get a tall version of the data, then we use extract and concat functions to get the desired output:

history_tall = pd.melt(history_wide, id_vars="epoch")
history_var_info = history_tall["variable"].str.extract(
    "(?P<prefix>val_|)(?P<metric>.*)")
history_tall_sets = pd.concat(
    [history_tall, history_var_info],
    axis=1)

Translating the code above back to R results in the code below, which is essentially what nc::capture_melt_single does under the hood:

history.tall <- data.table::melt(history.wide, id.vars="epoch")
history.var.info <- nc::capture_first_vec(
  history.tall$variable, prefix="val_|", metric=".*")
history.tall.sets <- data.table::data.table(
  history.tall, history.var.info)

The next step is to add a variable for the set name that we want to display in the plot. In R:

history.tall.sets[, set := ifelse(
  prefix=="val_", "validation", "subtrain")]

And in python:

history_tall_sets["set"] = history_tall_sets["prefix"].apply(
    lambda x: "validation" if x == "val_" else "subtrain")
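As a possible alternative to apply with a lambda, the same recoding can be sketched as a dictionary lookup with the pandas Series.map method. The small prefix Series below is a made-up stand-in for the prefix column.

```python
import pandas as pd

# Small stand-in for the prefix column of history_tall_sets.
prefix = pd.Series(["val_", "val_", "", ""], name="prefix")
# Dictionary lookup: "val_" -> validation, empty prefix -> subtrain.
set_names = prefix.map({"val_": "validation", "": "subtrain"})
print(set_names.tolist())
```

One difference to keep in mind: map gives a missing value for any prefix not in the dictionary, whereas the lambda version labels everything other than val_ as subtrain.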

Finally we plot these data using the R code

library(ggplot2)
ggplot()+
  geom_line(aes(
    x=epoch, y=value, color=set),
    data=history.tall.sets)+
  theme_bw()+
  theme(panel.spacing=grid::unit(0, "lines"))+
  facet_grid("metric")
ggsave("5-acc-loss.png", width=5, height=5)

Or the equivalent python code,

import plotnine as p9
gg = p9.ggplot(
    history_tall_sets,
    p9.aes(x="epoch", y="value", color="set"))+\
    p9.geom_line()+\
    p9.theme_bw()+\
    p9.facet_grid("metric ~ .", scales="free")+\
    p9.theme(
        facet_spacing={'right': 0.75}, #due to bug in legend.
        panel_spacing=0)
gg.save("5-acc-loss.png", width=5, height=5)

Note that the plotnine python module seems to be the current best implementation of ggplots in python, but I still prefer ggplots in R, especially because it is so easy to add textual labels for these kinds of plots via my directlabels package.

Comparison with datatable python module

The R data.table developers have created a port of their highly efficient C code to python, in the datatable module. However, I wasn’t able to use it to do the same computations as above, because it does not yet support the melt operation (wide to tall data reshaping). I posted an issue because the presence/absence of this key feature was not mentioned in the documentation. In that issue there is a link to a very detailed comparison of features provided by data.table and dplyr/tidyverse in R, which I would recommend reading for anyone doing data manipulation in R (and especially my students). Another interesting comparison shows that for some big data sets, reading CSV using datatable and then converting to pandas can actually be faster than directly reading data in pandas.