For my CS570 class this semester on Deep Learning, I prepared some figures comparing the logistic loss to the zero-one loss in binary classification. I also wanted to show the analogous loss functions for multi-class classification. This post explains how I did that in Python using numpy, pandas, and plotnine. First I define the loss functions in terms of the real-valued prediction score f and the label y (either -1 or 1),

import numpy as np
loss_dict = {
    # logistic loss: a smooth convex surrogate for the zero-one loss.
    "logistic":lambda f, y: np.log(1+np.exp(-y*f)),
    # zero-one loss: 1 if the sign of f disagrees with the label y.
    "zero-one":lambda f, y: np.where(f>0, 1, -1)!=y,
}
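
As a quick sanity check, we can evaluate both losses at a score of zero: the logistic loss should be log(2), and the zero-one loss should count a mistake on the positive label, since a score of zero is predicted as class -1 here.

loss_dict["logistic"](np.array([0.0]), 1)
## array([0.69314718])
loss_dict["zero-one"](np.array([0.0]), 1)
## array([ True])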

Then we compute those loss functions for both labels, on a grid of predicted scores from -5 to 5,

import pandas as pd
pred_lim = 5
pred_grid = np.linspace(-pred_lim, pred_lim) # 50 evenly spaced points by default
loss_df_list = []
for loss_name, loss_fun in loss_dict.items():
    for y in -1, 1:
        loss_df_list.append(pd.DataFrame({
            "loss_name":loss_name,
            "loss_value":loss_fun(pred_grid, y),
            "predicted_score":pred_grid,
            "label":y,
        }))
loss_df = pd.concat(loss_df_list)
loss_df
##    loss_name  loss_value  predicted_score  label
## 0   logistic    0.006715        -5.000000     -1
## 1   logistic    0.008229        -4.795918     -1
## 2   logistic    0.010083        -4.591837     -1
## 3   logistic    0.012352        -4.387755     -1
## 4   logistic    0.015127        -4.183673     -1
## ..       ...         ...              ...    ...
## 45  zero-one    0.000000         4.183673      1
## 46  zero-one    0.000000         4.387755      1
## 47  zero-one    0.000000         4.591837      1
## 48  zero-one    0.000000         4.795918      1
## 49  zero-one    0.000000         5.000000      1
## 
## [200 rows x 4 columns]

Then we plot these loss values with one panel for each label,

import plotnine as p9
def gg_binary(x):
    return p9.ggplot()+\
        p9.facet_grid(". ~ label", labeller="label_both")+\
        p9.scale_x_continuous(
            breaks=np.arange(-5, 7, 0.5 if x == "pred_prob1" else 2))+\
        p9.theme_bw()+\
        p9.theme(subplots_adjust={'right': 0.7, "bottom":0.2})+\
        p9.theme(figure_size=(4.5,2))+\
        p9.geom_point(
            p9.aes(
                x=x,
                y="loss_value",
                color="loss_name",
                ),
            data=loss_df)
show(gg_binary("predicted_score"), "binary-loss-scores")
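
The show helper is not defined in this post; a minimal sketch of what it could look like, assuming we just want to save each figure as a PNG named after the second argument:

def show(gg, name):
    # hypothetical stand-in for the helper that renders each figure
    # in this post: save the plotnine plot to a PNG file.
    gg.save(name + ".png")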

plot of binary-loss-scores

Next we can plot the loss as a function of the predicted probability of class 1, which we obtain from the score via the sigmoid function 1/(1+exp(-f)),

loss_df["pred_prob1"] = 1/(1+np.exp(-loss_df.predicted_score))
show(gg_binary("pred_prob1"), "binary-loss-prob")

plot of binary-loss-prob

How can we generalize this plot to more than two classes? We can plot the loss functions for three classes on the probability simplex, projected onto the Cartesian plotting plane as an equilateral triangle. First we compute the vertices of the equilateral triangle,

xmax = 1.0
upper_x = xmax/2.0
# height of the equilateral triangle, by the Pythagorean theorem.
ymax = np.sqrt(xmax**2 - upper_x**2)
# one row per vertex: (x, y, 1), with 1 as a homogeneous coordinate.
vertices_mat = np.array([
    [xmax, 0, 1],
    [upper_x, ymax, 1],
    [0,0,1]
])
vertices_mat
## array([[1.       , 0.       , 1.       ],
##        [0.5      , 0.8660254, 1.       ],
##        [0.       , 0.       , 1.       ]])
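
As a sanity check, the pairwise distances between the vertices should all equal xmax, confirming the triangle is equilateral:

from itertools import combinations
np.allclose([
    np.linalg.norm(a-b)
    for a, b in combinations(vertices_mat[:,:2], 2)
], xmax)
## True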

The first two columns of the matrix above are the (x,y) coordinates of the vertices of the equilateral triangle; the third column of ones is a homogeneous coordinate that we will use below to convert between coordinate systems. Each point inside this triangle represents a triple of probability values on the simplex, which sum to one and are all non-negative. To plot the loss function for each probability triple in the simplex, we can make a heat map by first computing a grid of (x,y) values:

def make_grid(mat, n_grid = 200):
    nrow, ncol = mat.shape
    assert ncol == 2
    # one evenly spaced vector of n_grid values per column,
    # ranging from that column's min to its max.
    mesh_args = mat.apply(
        lambda x: np.linspace(min(x),max(x), n_grid), axis=0)
    mesh_tup = np.meshgrid(*[mesh_args[x] for x in mesh_args])
    mesh_vectors = [v.flatten() for v in mesh_tup]
    return pd.DataFrame(dict(zip(mesh_args,mesh_vectors)))
simplex_grid = make_grid(pd.DataFrame({
    "x":np.linspace(0,xmax),
    "y":np.linspace(0,ymax)
}))
simplex_grid
##               x         y
## 0      0.000000  0.000000
## 1      0.005025  0.000000
## 2      0.010050  0.000000
## 3      0.015075  0.000000
## 4      0.020101  0.000000
## ...         ...       ...
## 39995  0.979899  0.866025
## 39996  0.984925  0.866025
## 39997  0.989950  0.866025
## 39998  0.994975  0.866025
## 39999  1.000000  0.866025
## 
## [40000 rows x 2 columns]

Now which of these rows falls within the equilateral triangle, and is therefore a valid probability triple on the simplex? We need a mapping from the (x,y) coordinates to probability triple coordinates. The mapping is affine, so by appending a constant 1 to each (x,y) point (the role of the column of ones in vertices_mat), we can express it as a single matrix multiplication. We want the matrix A which solves the linear system vertices_mat * A = I: each vertex of the triangle must map to a unit vector, since the unit vectors are the vertices of the simplex in probability coordinates. Multiplying both sides on the left by the inverse gives A = inv(vertices_mat),

to_prob_mat = np.linalg.inv(vertices_mat)
to_prob_mat
## array([[ 1.        ,  0.        , -1.        ],
##        [-0.57735027,  1.15470054, -0.57735027],
##        [ 0.        ,  0.        ,  1.        ]])
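
We can verify the solution by checking that multiplying the vertices by A recovers the identity matrix, meaning each triangle vertex maps to a unit vector on the simplex:

np.allclose(np.matmul(vertices_mat, to_prob_mat), np.eye(3))
## True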

We then can convert the x,y coordinates of the grid to probability coordinates:

# append the homogeneous coordinate (a column of ones) to each grid point.
simplex_grid_xy = np.column_stack(
    [simplex_grid, np.repeat(1, simplex_grid.shape[0])])
simplex_grid_prob = np.matmul(simplex_grid_xy, to_prob_mat)
simplex_grid_prob
## array([[ 0.        ,  0.        ,  1.        ],
##        [ 0.00502513,  0.        ,  0.99497487],
##        [ 0.01005025,  0.        ,  0.98994975],
##        ...,
##        [ 0.48994975,  1.        , -0.48994975],
##        [ 0.49497487,  1.        , -0.49497487],
##        [ 0.5       ,  1.        , -0.5       ]])
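
Note that each converted row sums to one (a consequence of the homogeneous coordinate), so a row is a valid probability triple whenever all of its entries are non-negative:

np.allclose(simplex_grid_prob.sum(axis=1), 1)
## True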

Then we exclude the rows with any negative probability values,

keep = simplex_grid_prob.min(axis=1) >= 0
keep_grid = pd.concat([
    pd.DataFrame(simplex_grid_prob), simplex_grid
], axis=1)[keep]
keep_grid
##               0         1         2         x         y
## 0      0.000000  0.000000  1.000000  0.000000  0.000000
## 1      0.005025  0.000000  0.994975  0.005025  0.000000
## 2      0.010050  0.000000  0.989950  0.010050  0.000000
## 3      0.015075  0.000000  0.984925  0.015075  0.000000
## 4      0.020101  0.000000  0.979899  0.020101  0.000000
## ...         ...       ...       ...       ...       ...
## 39300  0.010050  0.984925  0.005025  0.502513  0.852970
## 39301  0.015075  0.984925  0.000000  0.507538  0.852970
## 39499  0.002513  0.989950  0.007538  0.497487  0.857322
## 39500  0.007538  0.989950  0.002513  0.502513  0.857322
## 39700  0.005025  0.994975  0.000000  0.502513  0.861674
## 
## [19897 rows x 5 columns]
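
About half of the 40000 grid points are kept, which makes sense because the triangle covers half the area of its bounding rectangle:

keep.mean()
## 0.497425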

Next we compute the logistic loss for each of those grid points, and for each of the three labels,

def get_loss_df(loss_fun):
    loss_simplex_list = []
    for label in range(3):
        loss_only = pd.DataFrame({
            "loss":loss_fun(label),
            "label":label,
        })
        loss_grid = pd.concat([keep_grid.reset_index(), loss_only], axis=1)
        loss_simplex_list.append(loss_grid)
    return pd.concat(loss_simplex_list)
def logistic_loss(label, loss_max = 5):
    label_prob = keep_grid[label]
    # multi-class logistic loss: -log of the true label's probability.
    loss_vec = np.log(1/label_prob)
    # threshold loss for visualization purposes, to avoid saturation
    # (this also caps the inf values which occur where label_prob is 0).
    return np.where(loss_vec<loss_max, loss_vec, loss_max)
logistic_df = get_loss_df(logistic_loss)
logistic_df
##        index         0         1         2         x         y      loss  label
## 0          0  0.000000  0.000000  1.000000  0.000000  0.000000  5.000000      0
## 1          1  0.005025  0.000000  0.994975  0.005025  0.000000  5.000000      0
## 2          2  0.010050  0.000000  0.989950  0.010050  0.000000  4.600158      0
## 3          3  0.015075  0.000000  0.984925  0.015075  0.000000  4.194693      0
## 4          4  0.020101  0.000000  0.979899  0.020101  0.000000  3.907010      0
## ...      ...       ...       ...       ...       ...       ...       ...    ...
## 19892  39300  0.010050  0.984925  0.005025  0.502513  0.852970  5.000000      2
## 19893  39301  0.015075  0.984925  0.000000  0.507538  0.852970  5.000000      2
## 19894  39499  0.002513  0.989950  0.007538  0.497487  0.857322  4.887840      2
## 19895  39500  0.007538  0.989950  0.002513  0.502513  0.857322  5.000000      2
## 19896  39700  0.005025  0.994975  0.000000  0.502513  0.861674  5.000000      2
## 
## [59691 rows x 8 columns]
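
As expected, for each label the loss is smallest at the grid point closest to the vertex where that label's probability is largest. One way to check is to inspect the minimum-loss row for each label, after resetting the index (which is non-unique after the concat above):

check_df = logistic_df.reset_index(drop=True)
check_df.loc[check_df.groupby("label")["loss"].idxmin()]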

Next, we plot the loss values as a heat map on the equilateral triangle which represents the probability simplex,

def gg_loss(loss_name, grid_loss_df, breaks=[0,0.5,1]):
    return p9.ggplot()+\
        p9.ggtitle(loss_name+" loss on 3-simplex")+\
        p9.facet_grid(". ~ label", labeller="label_both")+\
        p9.geom_tile(
            p9.aes(
                x="x",
                y="y",
                fill="loss",
                ),
            data=grid_loss_df)+\
        p9.scale_fill_gradient(
            low="white",
            high="red")+\
        p9.scale_x_continuous(
            breaks=breaks)+\
        p9.scale_y_continuous(
            breaks=breaks)+\
        p9.coord_equal()
show(gg_loss("Logistic", logistic_df), "multi-logistic")

plot of multi-logistic

Next, we can compute the zero-one loss in a similar fashion,

def zero_one_loss(label):
    # predict the class with the largest probability; the loss is 1
    # if the prediction differs from the label, and 0 otherwise.
    loss_vec = np.array(keep_grid.loc[:,[0,1,2]]).argmax(axis=1) != label
    return np.where(loss_vec, 1, 0)
zero_one_df = get_loss_df(zero_one_loss)
zero_one_df
##        index         0         1         2         x         y  loss  label
## 0          0  0.000000  0.000000  1.000000  0.000000  0.000000     1      0
## 1          1  0.005025  0.000000  0.994975  0.005025  0.000000     1      0
## 2          2  0.010050  0.000000  0.989950  0.010050  0.000000     1      0
## 3          3  0.015075  0.000000  0.984925  0.015075  0.000000     1      0
## 4          4  0.020101  0.000000  0.979899  0.020101  0.000000     1      0
## ...      ...       ...       ...       ...       ...       ...   ...    ...
## 19892  39300  0.010050  0.984925  0.005025  0.502513  0.852970     1      2
## 19893  39301  0.015075  0.984925  0.000000  0.507538  0.852970     1      2
## 19894  39499  0.002513  0.989950  0.007538  0.497487  0.857322     1      2
## 19895  39500  0.007538  0.989950  0.002513  0.502513  0.857322     1      2
## 19896  39700  0.005025  0.994975  0.000000  0.502513  0.861674     1      2
## 
## [59691 rows x 8 columns]
show(gg_loss("Zero-one", zero_one_df), "multi-zero-one")

plot of multi-zero-one
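
Since the grid covers the simplex roughly uniformly, and the argmax rule splits the triangle into three regions of equal area, the mean zero-one loss for each label should be about 2/3:

zero_one_df.groupby("label")["loss"].mean()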

References:

  • https://en.wikipedia.org/wiki/Simplex#The_standard_simplex