Stratified batch sampler for mlr3torch
The goal of this post is to show how to use a custom torch sampler with mlr3torch, in order to implement stratified sampling, which can ensure that each batch in gradient descent has a minimum number of samples from each class. This is a continuation of a previous post on the same subject.
Motivation: imbalanced classification
We consider imbalanced classification problems, which occur frequently in many different application areas. As an example, let us look at the sonar data set.
sonar_task <- mlr3::tsk("sonar")
(count_tab <- table(sonar_task$data(sonar_task$row_ids, "Class")$Class))
##
## M R
## 111 97
The sonar label count table above is relatively balanced. Below we compute the frequencies.
count_tab/sum(count_tab)
##
## M R
## 0.5336538 0.4663462
We typically measure class imbalance by using the minority class proportion, which is 46.6% in this case. This is not much imbalance, because perfectly balanced data would be 50% minority class. We could consider an even greater imbalance by sub-sampling, which has been discussed in two previous posts:
- Creating large imbalanced data benchmarks, Tutorial with OpenML/higgs data.
- Creating imbalanced data benchmarks, Tutorial with MNIST.
With sonar we can do:
sonar_task$filter(208:86)
(count_tab <- table(sonar_task$data(sonar_task$row_ids, "Class")$Class))
##
## M R
## 111 12
count_tab/sum(count_tab)
##
## M R
## 0.90243902 0.09756098
The output above shows 10% minority class, which represents substantially more imbalance.
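The minority class proportion can also be computed directly from the count table; this one-liner is an addition to the original analysis, shown for clarity:
min(count_tab)/sum(count_tab)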
In mlr3 the column role stratum can be used to ensure that sampling preserves the class proportions, so that each sub-sample also has 10% minority class. To do that we use the code below:
sonar_task$col_roles$stratum <- "Class"
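As a quick illustration (an addition, not part of the original post), the sketch below checks that a stratified holdout split preserves the class proportions; it assumes the stratum role has been set as above.
resampling <- mlr3::rsmp("holdout", ratio = 2/3)
resampling$instantiate(sonar_task)
## Train set counts should be close to 90% M, 10% R.
table(sonar_task$data(resampling$train_set(1), "Class")$Class)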
Custom torch batch sampler
In mlr3 we have a task, and in torch the analogous concept is a dataset.
We access the dataset via a dataloader, which is the torch concept for dividing subtrain set samples into batches for gradient descent.
Each batch of samples is used to compute a gradient, and then update the model parameters.
For certain loss functions, we need at least two classes in the batch to get a non-zero gradient.
An example is the Area Under the Minimum (AUM) of False Positives and False Negatives, which we recently proposed for ROC curve optimization, in Journal of Machine Learning Research 2023.
We would like to use the stratum of the task in the context of the dataloader.
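To illustrate why stratification matters here, the sketch below (an addition, not from the original analysis) counts how many minority samples land in each batch when we use ordinary random batching of size 10 on the filtered sonar labels (111 M, 12 R).
torch::torch_manual_seed(1)
y <- rep(c("M", "R"), c(111, 12))
shuffled <- y[torch::as_array(torch::torch_randperm(123)) + 1L]
batch_id <- ceiling(seq_along(shuffled)/10)
## Tabulate the number of R samples per batch of 10.
table(sapply(split(shuffled, batch_id), function(b) sum(b == "R")))
Typically several of the batches contain no R samples at all, so a loss like AUM would have a zero gradient on those batches.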
To ensure that each batch has at least one of the minority class labels, we can define a batch sampler, as below.
library(data.table)
batch_sampler_stratified <- function(min_samples_per_stratum, shuffle=TRUE){
  torch::sampler(
    "StratifiedSampler",
    initialize = function(data_source) {
      self$data_source <- data_source
      TSK <- data_source$task
      self$stratum <- TSK$col_roles$stratum
      if(length(self$stratum)==0)stop(TSK$id, " task missing stratum column role")
      self$stratum_dt <- data.table(
        TSK$data(cols=self$stratum),
        row.id=1:TSK$nrow)
      self$set_batch_list()
    },
    set_batch_list = function() {
      get_indices <- if(shuffle){
        function(n)torch::as_array(torch::torch_randperm(n))+1L
      }else{
        function(n)1:n
      }
      index_dt <- self$stratum_dt[
        get_indices(.N)
      ][
      , i.in.stratum := 1:.N, by=c(self$stratum)
      ][]
      count_dt <- index_dt[, .(
        max.i=max(i.in.stratum)
      ), by=c(self$stratum)][order(max.i)]
      count_min <- count_dt$max.i[1]
      num_batches <- max(1, count_min %/% min_samples_per_stratum)
      max_samp <- num_batches * min_samples_per_stratum
      index_dt[
      , n.samp := i.in.stratum/max(i.in.stratum)*max_samp
      , by=c(self$stratum)
      ][
      , batch.i := ceiling(n.samp/min_samples_per_stratum)
      ][]
      ## Instrumentation only (removed in the package version): save one
      ## label count table per epoch so we can inspect stratification later.
      save_list$count[[paste("epoch", length(save_list$count)+1)]] <<- dcast(
        index_dt,
        batch.i ~ Class,
        list(length, indices=function(x)paste(x, collapse=",")),
        value.var="row.id"
      )[
      , labels := index_dt[, paste(Class, collapse=""), by=batch.i][order(batch.i), V1]
      ]
      self$batch_list <- split(index_dt$row.id, index_dt$batch.i)
      self$batch_sizes <- sapply(self$batch_list, length)
      self$batch_size_tab <- sort(table(self$batch_sizes))
      self$batch_size <- as.integer(names(self$batch_size_tab)[length(self$batch_size_tab)])
    },
    .iter = function() {
      batch.i <- 0
      function() {
        if (batch.i < length(self$batch_list)) {
          batch.i <<- batch.i + 1L
          indices <- self$batch_list[[batch.i]]
          ## Instrumentation only: save the indices used in each batch.
          save_list$indices[[length(save_list$indices)+1]] <<- indices
          if (batch.i == length(self$batch_list)) {
            self$set_batch_list()
          }
          return(indices)
        }
        coro::exhausted()
      }
    },
    .length = function() {
      length(self$batch_list)
    }
  )
}
Note in the definition above that
- initialize derives the stratification from the stratum role defined in the task.
- set_batch_list sets self$batch_list, which is a list with one element for each batch; each element is an integer vector of indices (a standalone sketch of this computation appears after this list).
- If shuffle=FALSE, then samples are seen in a deterministic order that matches the source data set.
- If shuffle=TRUE, then samples are seen in a random order, and this order is different in each epoch because set_batch_list is called to set a new self$batch_list after each epoch is complete.
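To make the index computation in set_batch_list more concrete, here is a standalone sketch (an addition, not from the original post) that reproduces the batch assignment for the filtered sonar class counts, using shuffle=FALSE and min_samples_per_stratum=1.
min_samples_per_stratum <- 1
stratum_dt <- data.table(Class=rep(c("M","R"), c(111,12)), row.id=1:123)
index_dt <- stratum_dt[, i.in.stratum := 1:.N, by=Class][]
count_min <- index_dt[, max(i.in.stratum), by=Class][, min(V1)]
num_batches <- max(1, count_min %/% min_samples_per_stratum)
max_samp <- num_batches * min_samples_per_stratum
index_dt[, n.samp := i.in.stratum/max(i.in.stratum)*max_samp, by=Class]
index_dt[, batch.i := ceiling(n.samp/min_samples_per_stratum)]
## Each of the 12 batches should get 9 or 10 M and exactly 1 R.
dcast(index_dt, batch.i ~ Class, length, value.var="row.id")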
We can use the batch_sampler_stratified function defined above to create the batch_sampler parameter of our learner.
Learner
Before defining the learner, we need to define its loss function.
In binary classification, the typical loss function is the logistic loss, torch::nn_bce_with_logits_loss.
The module below has a special forward method that saves the targets (so we can see if the stratification worked), and then uses that loss function.
nn_bce_with_logits_loss_save <- torch::nn_module(
  "nn_print_loss",
  inherit = torch::nn_mse_loss,
  initialize = function() {
    super$initialize()
    self$bce <- torch::nn_bce_with_logits_loss()
  },
  forward = function(input, target) {
    save_list$target[[length(save_list$target)+1]] <<- target
    self$bce(input, target)
  }
)
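As a quick sanity check (an addition, not in the original post), the saving loss module should return the same value as torch's BCE-with-logits loss on a toy example, while storing the target tensor in save_list as a side effect.
save_list <- list()
loss_fun <- nn_bce_with_logits_loss_save()
pred <- torch::torch_tensor(c(0.5, -1, 2))
label <- torch::torch_tensor(c(1, 0, 1))
loss_fun(pred, label)
torch::nnf_binary_cross_entropy_with_logits(pred, label)
length(save_list$target) # should be 1 after one forward call.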
Then we create a new MLP learner, which by default is a linear model.
set.seed(4)
mlp_learner <- mlr3torch::LearnerTorchMLP$new(
  task_type="classif",
  loss=nn_bce_with_logits_loss_save)
mlp_learner$predict_type <- "prob"
Then we set several learner parameters in the code below.
mlp_learner$param_set$set_values(
  epochs=1,
  p=0, # dropout probability.
  batch_size=1, # ignored.
  batch_sampler=batch_sampler_stratified(min_samples_per_stratum = 1))
In the code above we set parameters:
- epochs=1 for one epoch of learning.
- p=0 for no dropout regularization.
- batch_size=1 to avoid the error that this parameter is required; it is actually ignored because we also specify a batch_sampler. This is a bug which should be fixed by this PR.
- batch_sampler is set to our proposed stratified sampler, with min 1 sample per stratum.
In the code below we first initialize save_list, then we train:
save_list <- list()
mlp_learner$train(sonar_task)
names(save_list)
## [1] "count" "indices" "target"
The output above indicates that some data were created in save_list during training.
We can look at count to see how many labels were in each batch:
save_list$count
## $`epoch 1`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R row.id_indices_M row.id_indices_R labels
## <num> <int> <int> <char> <char> <char>
## 1: 1 9 1 105,98,9,41,101,60,18,109,43 119 MMRMMMMMMM
## 2: 2 9 1 53,62,96,40,30,104,44,27,83 113 MRMMMMMMMM
## 3: 3 9 1 75,14,16,26,19,110,42,93,107 115 RMMMMMMMMM
## 4: 4 10 1 6,79,46,37,21,66,58,54,100,39 114 RMMMMMMMMMM
## 5: 5 9 1 69,51,31,65,36,81,59,28,95 116 MMMMMMMMMR
## 6: 6 9 1 34,50,35,73,13,77,90,80,94 122 MMMMMMMMMR
## 7: 7 9 1 92,29,102,17,33,91,45,7,64 112 MMMMMMMMMR
## 8: 8 10 1 85,32,8,38,48,3,61,4,106,57 121 MMMMMMMMMMR
## 9: 9 9 1 63,68,10,74,2,22,70,52,72 123 MMMMMMMRMM
## 10: 10 9 1 78,97,86,55,47,12,76,24,108 120 MMMMMMMMMR
## 11: 11 9 1 23,5,89,25,84,56,67,15,49 118 MMMMMMMMMR
## 12: 12 10 1 20,103,87,88,111,1,99,71,82,11 117 MMMMMMMMMRM
##
## $`epoch 2`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R row.id_indices_M row.id_indices_R labels
## <num> <int> <int> <char> <char> <char>
## 1: 1 9 1 30,87,60,2,101,38,48,5,34 120 MMMMRMMMMM
## 2: 2 9 1 36,7,82,80,53,58,16,110,79 112 MMMMMMRMMM
## 3: 3 9 1 35,103,76,70,97,46,14,89,47 116 RMMMMMMMMM
## 4: 4 10 1 40,85,23,94,56,62,100,65,95,68 113 RMMMMMMMMMM
## 5: 5 9 1 93,74,73,25,1,29,105,63,75 118 MMMMMMMMMR
## 6: 6 9 1 107,21,9,44,91,45,92,6,52 123 MMMMMMMMMR
## 7: 7 9 1 98,24,55,88,108,99,106,66,10 117 MMMMMMMMMR
## 8: 8 10 1 51,109,11,59,64,49,39,86,41,19 119 MMMMMMMMMRM
## 9: 9 9 1 71,54,3,32,78,67,31,37,43 114 MMMMMMMMMR
## 10: 10 9 1 77,4,81,96,50,61,111,69,84 121 MMMMMRMMMM
## 11: 11 9 1 102,18,42,26,104,13,27,90,8 115 MMMMMMMMMR
## 12: 12 10 1 33,83,22,20,57,17,15,12,28,72 122 MMMMMMMRMMM
The output above comes from set_batch_list, and shows that
- there are two tables printed, one for the first epoch, and one for the second (computed but not used yet).
- each row represents a batch.
- in each table, the row.id_length_* columns show the number of positive and negative labels in a batch.
- the number of minority class samples (R) is always 1, because we specified min_samples_per_stratum = 1 in the call to batch_sampler_stratified in the code above (see the programmatic check below).
- the first batch in the first table has the same label counts as the first batch in the second table, etc.
- the row.id_indices_M of the first batch in the first table are different from the corresponding indices in the second table.
- so each epoch uses the samples in a different order, but with the same label counts in each batch.
Double-check correct indices and labels
The code below checks that the targets (=labels) seen by the loss function in each batch are consistent with the indices.
sonar_dataset <- mlp_learner$dataset(sonar_task)
for(batch.i in 1:length(save_list$target)){
  index_vec <- save_list$indices[[batch.i]]
  target_tensor <- save_list$target[[batch.i]]
  target_vec <- torch::as_array(target_tensor)
  label_vec <- names(count_tab)[ifelse(target_vec==1, 1, 2)]
  set(
    save_list$count[["epoch 1"]],
    i=batch.i,
    j="check",
    value=paste(label_vec, collapse=""))
  stopifnot(all.equal(
    target_tensor$flatten(),
    sonar_dataset$.getbatch(index_vec)$y$flatten()
  ))
}
save_list$count[["epoch 1"]]
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R row.id_indices_M row.id_indices_R labels check
## <num> <int> <int> <char> <char> <char> <char>
## 1: 1 9 1 105,98,9,41,101,60,18,109,43 119 MMRMMMMMMM MMRMMMMMMM
## 2: 2 9 1 53,62,96,40,30,104,44,27,83 113 MRMMMMMMMM MRMMMMMMMM
## 3: 3 9 1 75,14,16,26,19,110,42,93,107 115 RMMMMMMMMM RMMMMMMMMM
## 4: 4 10 1 6,79,46,37,21,66,58,54,100,39 114 RMMMMMMMMMM RMMMMMMMMMM
## 5: 5 9 1 69,51,31,65,36,81,59,28,95 116 MMMMMMMMMR MMMMMMMMMR
## 6: 6 9 1 34,50,35,73,13,77,90,80,94 122 MMMMMMMMMR MMMMMMMMMR
## 7: 7 9 1 92,29,102,17,33,91,45,7,64 112 MMMMMMMMMR MMMMMMMMMR
## 8: 8 10 1 85,32,8,38,48,3,61,4,106,57 121 MMMMMMMMMMR MMMMMMMMMMR
## 9: 9 9 1 63,68,10,74,2,22,70,52,72 123 MMMMMMMRMM MMMMMMMRMM
## 10: 10 9 1 78,97,86,55,47,12,76,24,108 120 MMMMMMMMMR MMMMMMMMMR
## 11: 11 9 1 23,5,89,25,84,56,67,15,49 118 MMMMMMMMMR MMMMMMMMMR
## 12: 12 10 1 20,103,87,88,111,1,99,71,82,11 117 MMMMMMMMMRM MMMMMMMMMRM
The output above has a new check column which is consistent with the previous labels column, as expected.
Both show the vector of labels in each batch.
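The same comparison can be done programmatically (an addition, not from the original post):
save_list$count[["epoch 1"]][, all(labels == check)]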
Varying batch size
In our proposed batch sampler, we are not able to directly control batch size.
The input parameter is min_samples_per_stratum, which affects the batch size.
In the previous section, we enforced min 1 sample per stratum (in each batch), which made batches of size 10 or 11.
Below we study how this parameter affects batch size.
label_count_dt_list <- list()
for(min_samples_per_stratum in c(1:7,20)){
  mlp_learner$param_set$set_values(
    epochs=1,
    p=0, # dropout probability.
    batch_size=1, # ignored.
    batch_sampler=batch_sampler_stratified(min_samples_per_stratum = min_samples_per_stratum))
  save_list <- list()
  mlp_learner$train(sonar_task)
  label_count_dt_list[[paste0(
    "min_samples_per_stratum=", min_samples_per_stratum
  )]] <- save_list$count[["epoch 1"]][, data.table(
    batch.i, row.id_length_M, row.id_length_R, batch_size=row.id_length_M+row.id_length_R)]
}
label_count_dt_list
## $`min_samples_per_stratum=1`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R batch_size
## <num> <int> <int> <int>
## 1: 1 9 1 10
## 2: 2 9 1 10
## 3: 3 9 1 10
## 4: 4 10 1 11
## 5: 5 9 1 10
## 6: 6 9 1 10
## 7: 7 9 1 10
## 8: 8 10 1 11
## 9: 9 9 1 10
## 10: 10 9 1 10
## 11: 11 9 1 10
## 12: 12 10 1 11
##
## $`min_samples_per_stratum=2`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R batch_size
## <num> <int> <int> <int>
## 1: 1 18 2 20
## 2: 2 19 2 21
## 3: 3 18 2 20
## 4: 4 19 2 21
## 5: 5 18 2 20
## 6: 6 19 2 21
##
## $`min_samples_per_stratum=3`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R batch_size
## <num> <int> <int> <int>
## 1: 1 27 3 30
## 2: 2 28 3 31
## 3: 3 28 3 31
## 4: 4 28 3 31
##
## $`min_samples_per_stratum=4`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R batch_size
## <num> <int> <int> <int>
## 1: 1 37 4 41
## 2: 2 37 4 41
## 3: 3 37 4 41
##
## $`min_samples_per_stratum=5`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R batch_size
## <num> <int> <int> <int>
## 1: 1 55 6 61
## 2: 2 56 6 62
##
## $`min_samples_per_stratum=6`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R batch_size
## <num> <int> <int> <int>
## 1: 1 55 6 61
## 2: 2 56 6 62
##
## $`min_samples_per_stratum=7`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R batch_size
## <num> <int> <int> <int>
## 1: 1 111 12 123
##
## $`min_samples_per_stratum=20`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R batch_size
## <num> <int> <int> <int>
## 1: 1 111 12 123
The output above shows that the batch size is always about 10x the value of min_samples_per_stratum, because the data set had about 10% minority class labels.
When min_samples_per_stratum is 7 or larger, we get a single batch with all samples (same as the full gradient method).
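The arithmetic behind these batch sizes can be sketched as follows (an addition, not from the original post): with count_min = 12 minority samples, the sampler creates max(1, 12 %/% min_samples_per_stratum) batches, so the batch size is roughly 123 divided by that number.
m <- c(1:7, 20)
num_batches <- pmax(1, 12 %/% m)
data.frame(min_samples_per_stratum=m, num_batches, approx_batch_size=round(123/num_batches))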
shuffle parameter
The shuffle argument to batch_sampler_stratified controls whether or not the data are seen in a random order in each epoch.
We can see the effect of this parameter via the code below.
shuffle_list <- list()
for(shuffle in c(TRUE, FALSE)){
  mlp_learner$param_set$set_values(
    epochs=1,
    p=0, # dropout probability.
    batch_size=1, # ignored.
    batch_sampler=batch_sampler_stratified(min_samples_per_stratum = 1, shuffle = shuffle))
  save_list <- list()
  mlp_learner$train(sonar_task)
  shuffle_list[[paste0("shuffle=", shuffle)]] <- save_list$count
}
shuffle_list
## $`shuffle=TRUE`
## $`shuffle=TRUE`$`epoch 1`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R row.id_indices_M row.id_indices_R labels
## <num> <int> <int> <char> <char> <char>
## 1: 1 9 1 43,39,2,14,57,93,89,41,9 112 MMMMMMMMMR
## 2: 2 9 1 44,31,109,18,68,55,5,12,110 116 MMMMMMMMMR
## 3: 3 9 1 75,42,77,50,96,108,103,10,98 115 MMMMMMRMMM
## 4: 4 10 1 30,3,51,53,72,8,67,33,107,27 123 RMMMMMMMMMM
## 5: 5 9 1 73,83,34,102,52,63,99,49,29 118 RMMMMMMMMM
## 6: 6 9 1 79,105,65,88,101,47,37,19,62 120 RMMMMMMMMM
## 7: 7 9 1 111,60,66,80,84,97,45,100,90 122 RMMMMMMMMM
## 8: 8 10 1 15,69,17,46,6,11,32,40,86,13 121 RMMMMMMMMMM
## 9: 9 9 1 25,28,95,22,23,78,87,106,71 114 RMMMMMMMMM
## 10: 10 9 1 81,16,20,38,82,91,36,70,76 119 RMMMMMMMMM
## 11: 11 9 1 59,21,85,7,61,48,24,4,35 113 RMMMMMMMMM
## 12: 12 10 1 26,58,104,64,94,92,56,1,74,54 117 RMMMMMMMMMM
##
## $`shuffle=TRUE`$`epoch 2`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R row.id_indices_M row.id_indices_R labels
## <num> <int> <int> <char> <char> <char>
## 1: 1 9 1 29,21,102,43,18,16,33,44,10 123 RMMMMMMMMM
## 2: 2 9 1 105,14,90,101,41,66,109,108,70 121 RMMMMMMMMM
## 3: 3 9 1 94,111,34,110,52,62,74,64,39 113 MMRMMMMMMM
## 4: 4 10 1 57,75,72,81,37,79,103,99,47,45 119 RMMMMMMMMMM
## 5: 5 9 1 4,6,76,31,59,65,89,95,27 120 RMMMMMMMMM
## 6: 6 9 1 93,5,35,55,46,8,92,85,11 117 MMMMMMMRMM
## 7: 7 9 1 40,63,20,104,24,1,61,13,9 118 MMMMMMMMMR
## 8: 8 10 1 17,60,48,96,83,97,68,23,71,54 122 MMMMMMMMMMR
## 9: 9 9 1 78,56,53,49,88,12,26,84,28 116 MMMMMMMMMR
## 10: 10 9 1 42,51,98,19,77,7,36,58,32 115 MMMMMMMRMM
## 11: 11 9 1 22,15,67,3,100,38,107,91,86 112 MRMMMMMMMM
## 12: 12 10 1 2,106,80,69,25,87,82,30,50,73 114 MMMMMMMRMMM
##
##
## $`shuffle=FALSE`
## $`shuffle=FALSE`$`epoch 1`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R row.id_indices_M row.id_indices_R labels
## <num> <int> <int> <char> <char> <char>
## 1: 1 9 1 1,2,3,4,5,6,7,8,9 112 MMMMMMMMMR
## 2: 2 9 1 10,11,12,13,14,15,16,17,18 113 MMMMMMMMMR
## 3: 3 9 1 19,20,21,22,23,24,25,26,27 114 MMMMMMMMMR
## 4: 4 10 1 28,29,30,31,32,33,34,35,36,37 115 MMMMMMMMMMR
## 5: 5 9 1 38,39,40,41,42,43,44,45,46 116 MMMMMMMMMR
## 6: 6 9 1 47,48,49,50,51,52,53,54,55 117 MMMMMMMMMR
## 7: 7 9 1 56,57,58,59,60,61,62,63,64 118 MMMMMMMMMR
## 8: 8 10 1 65,66,67,68,69,70,71,72,73,74 119 MMMMMMMMMMR
## 9: 9 9 1 75,76,77,78,79,80,81,82,83 120 MMMMMMMMMR
## 10: 10 9 1 84,85,86,87,88,89,90,91,92 121 MMMMMMMMMR
## 11: 11 9 1 93,94,95,96,97,98,99,100,101 122 MMMMMMMMMR
## 12: 12 10 1 102,103,104,105,106,107,108,109,110,111 123 MMMMMMMMMMR
##
## $`shuffle=FALSE`$`epoch 2`
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R row.id_indices_M row.id_indices_R labels
## <num> <int> <int> <char> <char> <char>
## 1: 1 9 1 1,2,3,4,5,6,7,8,9 112 MMMMMMMMMR
## 2: 2 9 1 10,11,12,13,14,15,16,17,18 113 MMMMMMMMMR
## 3: 3 9 1 19,20,21,22,23,24,25,26,27 114 MMMMMMMMMR
## 4: 4 10 1 28,29,30,31,32,33,34,35,36,37 115 MMMMMMMMMMR
## 5: 5 9 1 38,39,40,41,42,43,44,45,46 116 MMMMMMMMMR
## 6: 6 9 1 47,48,49,50,51,52,53,54,55 117 MMMMMMMMMR
## 7: 7 9 1 56,57,58,59,60,61,62,63,64 118 MMMMMMMMMR
## 8: 8 10 1 65,66,67,68,69,70,71,72,73,74 119 MMMMMMMMMMR
## 9: 9 9 1 75,76,77,78,79,80,81,82,83 120 MMMMMMMMMR
## 10: 10 9 1 84,85,86,87,88,89,90,91,92 121 MMMMMMMMMR
## 11: 11 9 1 93,94,95,96,97,98,99,100,101 122 MMMMMMMMMR
## 12: 12 10 1 102,103,104,105,106,107,108,109,110,111 123 MMMMMMMMMMR
The output above shows that
- when shuffle=TRUE, batch 1 has different indices in epoch 1 than in epoch 2 (and similarly for other batches).
- when shuffle=FALSE, batch 1 has the same indices in epoch 1 and 2 (and they are the first few indices of each class).
These results indicate that the shuffle parameter of the proposed sampler behaves as expected (consistent with standard batching in torch).
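The comparison described above can also be done programmatically (an addition, not from the original post), using the indices saved for each epoch.
same_indices <- function(count_list)identical(
  count_list[["epoch 1"]]$row.id_indices_M,
  count_list[["epoch 2"]]$row.id_indices_M)
## Expect FALSE for shuffle=TRUE and TRUE for shuffle=FALSE.
sapply(shuffle_list, same_indices)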
Randomness control, part 1: random weights
A torch model is initialized with random weights.
This is actually pseudo-randomness which can be controlled by the torch random seed.
By default mlr3torch learners will always use a different random initialization.
This can be controlled via the seed parameter, as demonstrated by the code below:
param_list <- list()
set.seed(5) # controls what is used as the torch random seed (if param not set).
for(set_seed in c(0:1)){
  for(rep_i in 1:2){
    L <- mlr3torch::LearnerTorchMLP$new(task_type="classif")
    L$param_set$set_values(epochs=0, batch_size=10)
    if(set_seed)L$param_set$values$seed <- 1 # torch random seed.
    L$train(sonar_task)
    param_list[[sprintf("set_seed=%d rep=%d", set_seed, rep_i)]] <- unlist(lapply(
      L$model$network$parameters, torch::as_array))
  }
}
The code above uses 0 epochs of training, so the weights simply come from the random initialization.
The first for loop is over set_seed values: 0 means to take the default (torch random seed depends on R random seed), 1 means to set the torch random seed to 1.
The second for loop is over rep_i values, which are simply repetitions, so we can see if we get the same weights after running the code twice.
Below we see the first few weights using each of the methods:
do.call(rbind, param_list)[, 1:5]
## 0.weight1 0.weight2 0.weight3 0.weight4 0.weight5
## set_seed=0 rep=1 0.03279998 -0.06481405 0.07047248 0.08959757 -0.09685164
## set_seed=0 rep=2 -0.01949527 0.07055710 0.02056149 0.09997202 0.03222404
## set_seed=1 rep=1 0.06652019 -0.05698169 -0.02502741 0.06059527 -0.12153898
## set_seed=1 rep=2 0.06652019 -0.05698169 -0.02502741 0.06059527 -0.12153898
We see above that
- the first two rows have different weight values, which indicates that the default is to use different random weights in each initialization.
- the last two rows have the same weight values, which indicates that the seed parameter controls the random weight initialization.
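The mechanism underneath is torch's own random seed; the minimal sketch below (an addition, not from the original post) shows that torch_manual_seed makes weight initialization reproducible, independently of mlr3torch.
torch::torch_manual_seed(1)
w1 <- torch::as_array(torch::nn_linear(3, 1)$weight)
torch::torch_manual_seed(1)
w2 <- torch::as_array(torch::nn_linear(3, 1)$weight)
identical(w1, w2) # should be TRUE.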
Randomness control, part 2: batching
What happens when we run gradient descent?
By default we go through batches in a random order, which induces another kind of randomness (other than the weight initialization) that is also controlled by the seed parameter, as shown by the experiment below.
param_list <- list()
set.seed(5) # controls what is used as the torch random seed (if param not set).
for(set_batch_sampler in c(0:1)){
  for(rep_i in 1:2){
    L <- mlr3torch::LearnerTorchMLP$new(task_type="classif")
    L$param_set$set_values(epochs=1, batch_size=10, seed=1)
    if(set_batch_sampler)L$param_set$values$batch_sampler <- batch_sampler_stratified(1)
    L$train(sonar_task)
    param_list[[sprintf("set_batch_sampler=%d rep=%d", set_batch_sampler, rep_i)]] <- unlist(lapply(
      L$model$network$parameters, torch::as_array))
  }
}
The code above replaces set_seed (now always 1) with set_batch_sampler in the first for loop.
When the batch sampler is not set, we use the usual batching mechanism, with batch_size=10.
When the batch sampler is set, we use our custom sampler defined above.
The other difference is that we have changed epochs from 0 to 1, so that we are using gradient descent to get weight values (in addition to the random initialization).
do.call(rbind, param_list)[, 1:5]
## 0.weight1 0.weight2 0.weight3 0.weight4 0.weight5
## set_batch_sampler=0 rep=1 0.07897462 -0.04479400 -0.01260542 0.07326440 -0.1090717
## set_batch_sampler=0 rep=2 0.07897462 -0.04479400 -0.01260542 0.07326440 -0.1090717
## set_batch_sampler=1 rep=1 0.14044657 -0.06968932 0.08062322 -0.05977938 -0.1097272
## set_batch_sampler=1 rep=2 0.14044657 -0.06968932 0.08062322 -0.05977938 -0.1097272
The output above shows that the first two rows are identical, as are the last two rows.
This result indicates that setting the seed parameter is sufficient to control randomness of both initialization and batching (using both the usual sampler, and our proposed sampler).
Note that this would not have been possible if we had used R's randomness in batch_sampler_stratified (exercise for the reader: replace torch::torch_randperm with sample and show that you get different weights between repetitions).
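A minimal sketch of that exercise's setup (an addition, not from the original post): torch_manual_seed controls torch_randperm, but not base R's sample(), which is why the proposed sampler uses torch's RNG.
torch::torch_manual_seed(1)
torch::as_array(torch::torch_randperm(5)) + 1L
torch::torch_manual_seed(1)
torch::as_array(torch::torch_randperm(5)) + 1L # same permutation as above.
torch::torch_manual_seed(1)
sample(5)
torch::torch_manual_seed(1)
sample(5) # generally differs: torch_manual_seed does not touch R's RNG.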
Conclusions
We have explained how to create a custom stratified sampler for use in the mlr3torch framework.
This will be useful in experiments with loss functions that require a minimal number of samples of each class to get a non-zero gradient.
For practical applications of this sampler, rather than using the code on this page (which has unnecessary save_list instrumentation), please check out library(mlr3torchAUM) on GitHub.
The version of the function in the package has the instrumentation removed.
remotes::install_github("tdhock/mlr3torchAUM")
## Using github PAT from envvar GITHUB_PAT. Use `gitcreds::gitcreds_set()` and unset GITHUB_PAT in .Renviron (or elsewhere) if you want to use the more secure git credential store instead.
## Skipping install of 'mlr3torchAUM' from a github remote, the SHA1 (9aa3447d) has not changed since last install.
## Use `force = TRUE` to force installation
mlr3torchAUM::batch_sampler_stratified
## function (min_samples_per_stratum, shuffle = TRUE)
## {
## torch::sampler("StratifiedSampler", initialize = function(data_source) {
## self$data_source <- data_source
## TSK <- data_source$task
## self$stratum <- TSK$col_roles$stratum
## if (length(self$stratum) == 0)
## stop(TSK$id, "task missing stratum column role")
## self$stratum_dt <- data.table(TSK$data(cols = self$stratum),
## row.id = 1:TSK$nrow)
## self$set_batch_list()
## }, set_batch_list = function() {
## get_indices <- if (shuffle) {
## function(n) torch::as_array(torch::torch_randperm(n)) +
## 1L
## }
## else {
## function(n) 1:n
## }
## index_dt <- self$stratum_dt[get_indices(.N)][, `:=`(i.in.stratum,
## 1:.N), by = c(self$stratum)][]
## count_dt <- index_dt[, .(max.i = max(i.in.stratum)),
## by = c(self$stratum)][order(max.i)]
## count_min <- count_dt$max.i[1]
## num_batches <- max(1, count_min%/%min_samples_per_stratum)
## max_samp <- num_batches * min_samples_per_stratum
## index_dt[, `:=`(n.samp, i.in.stratum/max(i.in.stratum) *
## max_samp), by = c(self$stratum)][, `:=`(batch.i,
## ceiling(n.samp/min_samples_per_stratum))][]
## self$batch_list <- split(index_dt$row.id, index_dt$batch.i)
## self$batch_sizes <- sapply(self$batch_list, length)
## self$batch_size_tab <- sort(table(self$batch_sizes))
## self$batch_size <- as.integer(names(self$batch_size_tab)[length(self$batch_size_tab)])
## }, .iter = function() {
## batch.i <- 0
## function() {
## if (batch.i < length(self$batch_list)) {
## batch.i <<- batch.i + 1L
## indices <- self$batch_list[[batch.i]]
## if (batch.i == length(self$batch_list)) {
## self$set_batch_list()
## }
## return(indices)
## }
## coro::exhausted()
## }
## }, .length = function() {
## length(self$batch_list)
## })
## }
## <bytecode: 0x5aed17c47ff8>
## <environment: namespace:mlr3torchAUM>
Session info
sessionInfo()
## R version 4.5.1 (2025-06-13)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.3 LTS
##
## Matrix products: default
## BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.12.0
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.12.0 LAPACK version 3.12.0
##
## locale:
## [1] LC_CTYPE=fr_FR.UTF-8 LC_NUMERIC=C LC_TIME=fr_FR.UTF-8 LC_COLLATE=fr_FR.UTF-8 LC_MONETARY=fr_FR.UTF-8
## [6] LC_MESSAGES=fr_FR.UTF-8 LC_PAPER=fr_FR.UTF-8 LC_NAME=C LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=fr_FR.UTF-8 LC_IDENTIFICATION=C
##
## time zone: America/Toronto
## tzcode source: system (glibc)
##
## attached base packages:
## [1] stats graphics grDevices datasets utils methods base
##
## other attached packages:
## [1] data.table_1.17.8
##
## loaded via a namespace (and not attached):
## [1] bit_4.6.0 compiler_4.5.1 crayon_1.5.3 Rcpp_1.1.0 bspm_0.5.7 parallel_4.5.1
## [7] callr_3.7.6 globals_0.18.0 uuid_1.2-1 R6_2.6.1 curl_7.0.0 knitr_1.50
## [13] palmerpenguins_0.1.1 backports_1.5.0 checkmate_2.3.3 future_1.67.0 desc_1.4.3 paradox_1.0.1
## [19] rlang_1.1.6 mlr3torch_0.3.1 lgr_0.5.0 xfun_0.53 mlr3_1.1.0 mlr3misc_0.18.0
## [25] bit64_4.6.0-1 cli_3.6.5 withr_3.0.2 magrittr_2.0.3 ps_1.9.1 digest_0.6.37
## [31] processx_3.8.6 remotes_2.5.0 torch_0.16.0 mlr3torchAUM_2025.9.8 mlr3pipelines_0.9.0 coro_1.1.0
## [37] evaluate_1.0.5 listenv_0.9.1 codetools_0.2-20 pkgbuild_1.4.8 parallelly_1.45.1 tools_4.5.1