Plotting the probability simplex
For my CS570 class on Deep Learning this semester, I prepared some figures which compare the logistic loss to the zero-one loss in binary classification. I also wanted to show the loss functions for multi-class classification. This post explains how I did that in Python using numpy, pandas, and plotnine. First I define the loss functions in terms of the real-valued prediction score $f$ and the label $y \in \{-1, 1\}$:
import numpy as np
# Binary-classification losses as functions of the real-valued score ``f``
# and the label ``y`` (either -1 or 1); ``f`` may be a NumPy array.
loss_dict = {
    # log(1 + exp(-y*f)) computed with logaddexp so that very negative
    # margins give a large finite loss instead of overflowing exp() to inf.
    "logistic": lambda f, y: np.logaddexp(0, -y * f),
    # Predict +1 when f > 0, otherwise -1; the loss is 1 for a wrong
    # prediction. Cast to float so this column has the same numeric dtype as
    # the logistic loss when both are stacked into one DataFrame (bool mixed
    # with float can otherwise concatenate to an object column in pandas).
    "zero-one": lambda f, y: (np.where(f > 0, 1, -1) != y).astype(float),
}
Then we compute those loss functions for both labels, on a grid of predicted scores from -5 to 5,
import pandas as pd
# Evaluate every loss, for both labels, on 50 scores evenly spaced in
# [-pred_lim, pred_lim], producing one long-format table.
pred_lim = 5
pred_grid = np.linspace(-pred_lim, pred_lim)
loss_df_list = [
    pd.DataFrame({
        "loss_name": loss_name,
        "loss_value": loss_fun(pred_grid, y),
        "predicted_score": pred_grid,
        "label": y,
    })
    for loss_name, loss_fun in loss_dict.items()
    for y in (-1, 1)
]
loss_df = pd.concat(loss_df_list)
loss_df
## loss_name loss_value predicted_score label
## 0 logistic 0.006715 -5.000000 -1
## 1 logistic 0.008229 -4.795918 -1
## 2 logistic 0.010083 -4.591837 -1
## 3 logistic 0.012352 -4.387755 -1
## 4 logistic 0.015127 -4.183673 -1
## .. ... ... ... ...
## 45 zero-one 0.000000 4.183673 1
## 46 zero-one 0.000000 4.387755 1
## 47 zero-one 0.000000 4.591837 1
## 48 zero-one 0.000000 4.795918 1
## 49 zero-one 0.000000 5.000000 1
##
## [200 rows x 4 columns]
Then we plot these loss values with one panel for each label,
import plotnine as p9
def gg_binary(x):
    """Return a plotnine scatter plot of the binary loss values versus ``x``.

    Reads the module-level ``loss_df`` table: one facet per label and one
    colour per loss function.

    Parameters
    ----------
    x : str
        Column of ``loss_df`` for the horizontal axis, e.g.
        ``"predicted_score"`` or ``"pred_prob1"``.
    """
    # Compare strings by value: ``is`` tests object identity, only works by
    # accident of string interning, and triggers a SyntaxWarning (3.8+).
    # The probability column is a sigmoid output in [0, 1], so it gets
    # finer axis breaks than the raw scores.
    break_step = 0.5 if x == "pred_prob1" else 2
    return (
        p9.ggplot()
        + p9.facet_grid(". ~ label", labeller="label_both")
        + p9.scale_x_continuous(breaks=np.arange(-5, 7, break_step))
        + p9.theme_bw()
        + p9.theme(subplots_adjust={'right': 0.7, "bottom": 0.2})
        + p9.theme(figure_size=(4.5, 2))
        + p9.geom_point(
            p9.aes(
                x=x,
                y="loss_value",
                color="loss_name",
            ),
            data=loss_df)
    )


show(gg_binary("predicted_score"), "binary-loss-scores")
Next we can plot the loss as a function of the predicted probability of class 1, obtained by passing the score through the logistic (sigmoid) function $1/(1+e^{-f})$:
# Squash each real-valued score through the sigmoid to get P(class 1).
loss_df["pred_prob1"] = 1 / (1 + np.exp(-loss_df["predicted_score"]))
show(gg_binary("pred_prob1"), "binary-loss-prob")
How can we generalize this plot to more than two classes? We can plot the loss functions for three classes on the probability simplex, projected onto the Cartesian plotting plane as an equilateral triangle. First we compute the vertices of the equilateral triangle:
# Vertices of an equilateral triangle with side length ``xmax``, one vertex
# per class. Each row is (x, y, 1): the trailing 1 is a homogeneous
# coordinate, so a single matrix inverse later gives the affine map from
# plot coordinates to probability coordinates.
xmax = 1.0
upper_x = xmax / 2.0
# Height of an equilateral triangle with side xmax (Pythagoras):
# sqrt(xmax**2 - (xmax/2)**2). Writing ``xmax - upper_x**2`` would only be
# correct in the special case xmax == 1.
ymax = np.sqrt(xmax**2 - upper_x**2)
vertices_mat = np.array([
    [xmax, 0, 1],
    [upper_x, ymax, 1],
    [0, 0, 1],
])
vertices_mat
## array([[1. , 0. , 1. ],
## [0.5 , 0.8660254, 1. ],
## [0. , 0. , 1. ]])
The first two columns of the matrix above are the (x, y) coordinates of the vertices of the equilateral triangle. The third column of ones lets a single matrix product represent the affine map to probability coordinates used below. Each point inside this triangle represents a triple of probability values on the simplex, which sum to one and are all non-negative. To plot the loss function for each probability triple in the simplex, we can make a heat map by first computing a grid of (x,y) values:
def make_grid(mat, n_grid=200):
    """Return a regular 2-D grid spanning the bounding box of ``mat``.

    Parameters
    ----------
    mat : pandas.DataFrame
        Exactly two numeric columns; only each column's min and max are used.
    n_grid : int
        Number of evenly spaced values per axis (default 200).

    Returns
    -------
    pandas.DataFrame
        ``n_grid**2`` rows with the same column names as ``mat``. The first
        column varies fastest (``np.meshgrid`` default ``"xy"`` indexing,
        flattened in row-major order).

    Raises
    ------
    ValueError
        If ``mat`` does not have exactly two columns.
    """
    # Explicit check instead of ``assert``, which is stripped under ``-O``.
    n_cols = mat.shape[1]
    if n_cols != 2:
        raise ValueError(f"make_grid expects exactly 2 columns, got {n_cols}")
    axes = {
        col: np.linspace(mat[col].min(), mat[col].max(), n_grid)
        for col in mat.columns
    }
    mesh = np.meshgrid(*axes.values())
    return pd.DataFrame({col: m.ravel() for col, m in zip(axes, mesh)})
# Cover the triangle's bounding box [0, xmax] x [0, ymax] with a regular grid;
# make_grid only needs each column's extremes.
simplex_grid = make_grid(
    pd.DataFrame(dict(x=np.linspace(0, xmax), y=np.linspace(0, ymax))))
simplex_grid
## x y
## 0 0.000000 0.000000
## 1 0.005025 0.000000
## 2 0.010050 0.000000
## 3 0.015075 0.000000
## 4 0.020101 0.000000
## ... ... ...
## 39995 0.979899 0.866025
## 39996 0.984925 0.866025
## 39997 0.989950 0.866025
## 39998 0.994975 0.866025
## 39999 1.000000 0.866025
##
## [40000 rows x 2 columns]
Now which of these rows fall within the equilateral triangle, and are therefore valid probability triples in the simplex? We need a mapping from these (x,y) coordinates to probability coordinates. We can get it by solving the linear system $V A = I$, where $V$ is `vertices_mat`. The matrix $A$ is the linear transformation which converts a row vector $[x, y, 1]$ to its probability triple. The identity matrix $I$ holds the unit vectors, which are the vertices of the simplex in probability coordinates. To solve for $A$ we left-multiply both sides by the inverse, giving $A = V^{-1}$:
# Row-vector convention: [x, y, 1] @ to_prob_mat is that point's probability
# triple, since vertices_mat @ to_prob_mat is the identity matrix (each
# vertex maps to a unit vector).
to_prob_mat = np.linalg.inv(a=vertices_mat)
to_prob_mat
## array([[ 1. , 0. , -1. ],
## [-0.57735027, 1.15470054, -0.57735027],
## [ 0. , 0. , 1. ]])
We then can convert the x,y coordinates of the grid to probability coordinates:
# Append a constant 1 to each (x, y) row (homogeneous coordinates), then map
# every row to its probability triple with one matrix product.
simplex_grid_xy = np.column_stack(
    (simplex_grid, np.ones(simplex_grid.shape[0], dtype=int)))
simplex_grid_prob = simplex_grid_xy @ to_prob_mat
simplex_grid_prob
## array([[ 0. , 0. , 1. ],
## [ 0.00502513, 0. , 0.99497487],
## [ 0.01005025, 0. , 0.98994975],
## ...,
## [ 0.48994975, 1. , -0.48994975],
## [ 0.49497487, 1. , -0.49497487],
## [ 0.5 , 1. , -0.5 ]])
Then we exclude the rows with any negative probability values,
# A grid point lies inside the triangle exactly when none of its three
# probabilities is negative; keep those rows, with probability columns 0, 1, 2
# placed next to the original x, y coordinates.
keep = np.all(simplex_grid_prob >= 0, axis=1)
combined_grid = pd.concat([pd.DataFrame(simplex_grid_prob), simplex_grid], axis=1)
keep_grid = combined_grid.loc[keep]
keep_grid
## 0 1 2 x y
## 0 0.000000 0.000000 1.000000 0.000000 0.000000
## 1 0.005025 0.000000 0.994975 0.005025 0.000000
## 2 0.010050 0.000000 0.989950 0.010050 0.000000
## 3 0.015075 0.000000 0.984925 0.015075 0.000000
## 4 0.020101 0.000000 0.979899 0.020101 0.000000
## ... ... ... ... ... ...
## 39300 0.010050 0.984925 0.005025 0.502513 0.852970
## 39301 0.015075 0.984925 0.000000 0.507538 0.852970
## 39499 0.002513 0.989950 0.007538 0.497487 0.857322
## 39500 0.007538 0.989950 0.002513 0.502513 0.857322
## 39700 0.005025 0.994975 0.000000 0.502513 0.861674
##
## [19897 rows x 5 columns]
Next we compute the logistic loss for each one of those grid points, and for each of the three labels,
def get_loss_df(loss_fun):
    """Evaluate ``loss_fun`` for labels 0, 1, 2 on the module-level ``keep_grid``.

    Parameters
    ----------
    loss_fun : callable
        Called as ``loss_fun(label)``; must return one loss per row of
        ``keep_grid``.

    Returns
    -------
    pandas.DataFrame
        Three stacked copies of ``keep_grid.reset_index()`` (so the original
        grid row label becomes an ``index`` column), each extended with a
        ``loss`` column and a constant ``label`` column.
    """
    def one_label(label):
        # Compute the loss first, then attach it to a freshly re-indexed grid
        # so both frames align on positions 0..n-1.
        label_loss = pd.DataFrame({
            "loss": loss_fun(label),
            "label": label,
        })
        return pd.concat([keep_grid.reset_index(), label_loss], axis=1)

    return pd.concat([one_label(label) for label in range(3)])
def logistic_loss(label, loss_max=5):
    """Return ``log(1 / p)`` for each ``keep_grid`` row, clipped at ``loss_max``.

    ``p`` is the probability column ``label`` of the module-level
    ``keep_grid``, i.e. the probability that grid point assigns to ``label``.
    """
    neg_log_prob = np.log(1 / keep_grid[label])
    # Cap large values (infinite where p == 0) so the heat map colour scale is
    # not saturated by the extreme losses near the simplex edges.
    return np.where(np.less(neg_log_prob, loss_max), neg_log_prob, loss_max)


logistic_df = get_loss_df(logistic_loss)
logistic_df
## index 0 1 2 x y loss label
## 0 0 0.000000 0.000000 1.000000 0.000000 0.000000 5.000000 0
## 1 1 0.005025 0.000000 0.994975 0.005025 0.000000 5.000000 0
## 2 2 0.010050 0.000000 0.989950 0.010050 0.000000 4.600158 0
## 3 3 0.015075 0.000000 0.984925 0.015075 0.000000 4.194693 0
## 4 4 0.020101 0.000000 0.979899 0.020101 0.000000 3.907010 0
## ... ... ... ... ... ... ... ... ...
## 19892 39300 0.010050 0.984925 0.005025 0.502513 0.852970 5.000000 2
## 19893 39301 0.015075 0.984925 0.000000 0.507538 0.852970 5.000000 2
## 19894 39499 0.002513 0.989950 0.007538 0.497487 0.857322 4.887840 2
## 19895 39500 0.007538 0.989950 0.002513 0.502513 0.857322 5.000000 2
## 19896 39700 0.005025 0.994975 0.000000 0.502513 0.861674 5.000000 2
##
## [59691 rows x 8 columns]
Finally, we plot the loss values as a heatmap on the equilateral triangle which represents the probability simplex,
def gg_loss(loss_name, grid_loss_df, breaks=None):
    """Return a plotnine heat map of a loss over the triangle of probabilities.

    Parameters
    ----------
    loss_name : str
        Prefix of the plot title.
    grid_loss_df : pandas.DataFrame
        Needs columns ``x`` and ``y`` (plot coordinates), ``loss`` (tile fill)
        and ``label`` (one facet per label).
    breaks : list of float, optional
        Tick positions for both axes; defaults to ``[0, 0.5, 1]``.
    """
    if breaks is None:
        # Build the default per call rather than sharing a single mutable
        # list object across every call (mutable default argument).
        breaks = [0, 0.5, 1]
    return (
        p9.ggplot()
        + p9.ggtitle(loss_name + " loss on 3-simplex")
        + p9.facet_grid(". ~ label", labeller="label_both")
        + p9.geom_tile(
            p9.aes(
                x="x",
                y="y",
                fill="loss",
            ),
            data=grid_loss_df)
        + p9.scale_fill_gradient(
            low="white",
            high="red")
        + p9.scale_x_continuous(
            breaks=breaks)
        + p9.scale_y_continuous(
            breaks=breaks)
        # Equal aspect ratio so the triangle is drawn equilateral.
        + p9.coord_equal()
    )


show(gg_loss("Logistic", logistic_df), "multi-logistic")
Next, we can compute the zero-one loss in a similar fashion,
def zero_one_loss(label):
    """Return 1 per ``keep_grid`` row whose most probable class is not ``label``.

    The predicted class is the argmax over probability columns 0, 1, 2 of the
    module-level ``keep_grid`` (ties go to the lowest column index); the
    result is an integer array of 0/1 values.
    """
    predicted_class = keep_grid.loc[:, [0, 1, 2]].to_numpy().argmax(axis=1)
    return (predicted_class != label).astype(int)


zero_one_df = get_loss_df(zero_one_loss)
zero_one_df
## index 0 1 2 x y loss label
## 0 0 0.000000 0.000000 1.000000 0.000000 0.000000 1 0
## 1 1 0.005025 0.000000 0.994975 0.005025 0.000000 1 0
## 2 2 0.010050 0.000000 0.989950 0.010050 0.000000 1 0
## 3 3 0.015075 0.000000 0.984925 0.015075 0.000000 1 0
## 4 4 0.020101 0.000000 0.979899 0.020101 0.000000 1 0
## ... ... ... ... ... ... ... ... ...
## 19892 39300 0.010050 0.984925 0.005025 0.502513 0.852970 1 2
## 19893 39301 0.015075 0.984925 0.000000 0.507538 0.852970 1 2
## 19894 39499 0.002513 0.989950 0.007538 0.497487 0.857322 1 2
## 19895 39500 0.007538 0.989950 0.002513 0.502513 0.857322 1 2
## 19896 39700 0.005025 0.994975 0.000000 0.502513 0.861674 1 2
##
## [59691 rows x 8 columns]
# Render the zero-one heat map, one facet per label.
show(gg_loss("Zero-one", grid_loss_df=zero_one_df), "multi-zero-one")
References:
- https://en.wikipedia.org/wiki/Simplex#The_standard_simplex