A custom DataLoader for mlr3torch
The goal of this post is to show how to use a custom torch sampler with mlr3torch, in order to use stratified sampling, which can ensure that each batch in gradient descent has a minimum number of samples from each class.
Motivation: imbalanced classification
We consider imbalanced classification problems, which occur frequently in many different areas. For example, in a recent project involving predicting childhood autism, we used data from the National Survey of Children’s Health (NSCH), which had about 3% autism and about 20K rows. Below we simulate a data set with similar proportions, using 20,050 rows (the extra 50 rows will show how a smaller batch at the end is handled).
library(data.table)
prop.pos <- 0.03
Nrow <- 20050
(aut_sim <- data.table(autism=rep(c(1,0), round(c(prop.pos, 1-prop.pos)*Nrow))))
## autism
## <num>
## 1: 1
## 2: 1
## 3: 1
## 4: 1
## 5: 1
## ---
## 20046: 0
## 20047: 0
## 20048: 0
## 20049: 0
## 20050: 0
To learn with these data in torch, we can use stochastic gradient descent. To do that, we need to wrap the data table in a dataset, as below:
ds_gen <- torch::dataset(
  initialize=function(){},
  .getbatch=function(i)aut_sim[i]$autism,
  .length=function()nrow(aut_sim))
After that, we need to attach the dataset to a dataloader with a certain batch size, which we define as 100 below:
ds <- ds_gen()
batch_size <- 100
dl <- torch::dataloader(ds, batch_size=batch_size, shuffle=TRUE)
To iterate through the dataloader, we use a loop, and we count the number of each class in each batch:
torch::torch_manual_seed(1)
count_dt_list <- list()
coro::loop(for (batch_tensor in dl) {
  batch_vec <- torch::as_array(batch_tensor)
  batch_id <- length(count_dt_list) + 1L
  count_dt_list[[batch_id]] <- data.table(
    batch_id,
    num_0=sum(batch_vec==0),
    num_1=sum(batch_vec==1))
})
(count_dt <- rbindlist(count_dt_list))
## batch_id num_0 num_1
## <int> <int> <int>
## 1: 1 98 2
## 2: 2 97 3
## 3: 3 97 3
## 4: 4 97 3
## 5: 5 96 4
## ---
## 197: 197 95 5
## 198: 198 99 1
## 199: 199 99 1
## 200: 200 99 1
## 201: 201 49 1
(batches_without_positive_labels <- count_dt[num_1==0])
## batch_id num_0 num_1
## <int> <int> <int>
## 1: 25 100 0
## 2: 26 100 0
## 3: 47 100 0
## 4: 55 100 0
## 5: 57 100 0
## 6: 59 100 0
## 7: 105 100 0
## 8: 110 100 0
## 9: 161 100 0
We can see above that there are 9 batches that have no positive labels. On average there should be 3 positive labels per batch, and in fact that is true:
quantile(count_dt$num_1)
## 0% 25% 50% 75% 100%
## 0 2 3 4 8
Above we see that the 50% quantile (median) of the number of positive labels per batch is equal to 3, as expected.
With typical loss functions, such as the cross-entropy (logistic) loss, gradient descent learning still works with a batch containing only negative examples. However, that is not the case with more complex loss functions such as the Area Under the Minimum (AUM) of False Positives and False Negatives, which we recently proposed for ROC curve optimization, in Journal of Machine Learning Research 2023. Computing the AUM requires computing a ROC curve, which is not possible without at least one positive and one negative example. So in gradient descent learning, a batch with only negative examples can not be used to compute the AUM, and must be skipped, which wastes time. It would be better to use stratified sampling, to control the number of minority class samples per batch.
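To make the wasted iterations concrete, below is a minimal sketch of a training loop that skips batches without both classes; compute_aum_loss() is a hypothetical placeholder, not a function from any package.

n_skipped <- 0
coro::loop(for (batch_tensor in dl) {
  label_vec <- torch::as_array(batch_tensor)
  if (length(unique(label_vec)) < 2) {
    # One class is missing, so the ROC curve (and therefore the AUM) is
    # undefined: this batch is wasted, no gradient step is possible.
    n_skipped <- n_skipped + 1
  } else {
    # loss <- compute_aum_loss(batch_tensor)  # hypothetical placeholder.
    # loss$backward(); optimizer$step()       # usual gradient step.
  }
})
n_skipped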
Stratified sampling
One way to implement stratified sampling is via the code below. First we define a minimum number of samples per stratum:
(min_samples_per_stratum <- prop.pos*batch_size)
## [1] 3
The output above shows that there will be at least 3 samples from each stratum in each batch. Below we shuffle the data set, and count the number of samples in each stratum:
stratum <- "autism"
set.seed(1)
(shuffle_dt <- aut_sim[
, row.id := 1:.N
][sample(.N)][
, i.in.stratum := 1:.N, keyby=stratum
][])
## Key: <autism>
## autism row.id i.in.stratum
## <num> <int> <int>
## 1: 0 17401 1
## 2: 0 4775 2
## 3: 0 13218 3
## 4: 0 10539 4
## 5: 0 8462 5
## ---
## 20046: 1 425 598
## 20047: 1 589 599
## 20048: 1 437 600
## 20049: 1 225 601
## 20050: 1 33 602
The output above shows two new columns:
- row.id is the row number in the original data table.
- i.in.stratum is the row number of the shuffled data, relative to the stratum (autism).
Next, we count the number of samples per stratum.
(count_dt <- shuffle_dt[, .(max.i=max(i.in.stratum)), by=stratum][order(max.i)])
## autism max.i
## <num> <int>
## 1: 1 602
## 2: 0 19448
(count_min <- count_dt$max.i[1])
## [1] 602
(num_batches <- count_min %/% min_samples_per_stratum)
## [1] 200
(max_samp <- num_batches * min_samples_per_stratum)
## [1] 600
Above, we see that the smallest stratum has 602 samples. The number of batches, num_batches, is computed with integer division, so that we never have fewer than min_samples_per_stratum samples of the smallest stratum in any batch.
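As a quick arithmetic check (a standalone sketch, using only the counts shown above), we can reproduce the batch assignment of the 602 minority samples and confirm that every batch gets at least 3 of them:

# For the smallest stratum: n.samp = i/602*600 and batch.i = ceiling(n.samp/3).
minority.per.batch <- table(ceiling((1:602) / 602 * 600 / 3))
# How many batches get 3 vs 4 minority samples? (expect 198 and 2)
table(as.integer(minority.per.batch))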
Next, we add a column n.samp with values between 0 and max_samp:
shuffle_dt[
, n.samp := i.in.stratum/max(i.in.stratum)*max_samp, by=stratum
][]
## Key: <autism>
## autism row.id i.in.stratum n.samp
## <num> <int> <int> <num>
## 1: 0 17401 1 0.0308515
## 2: 0 4775 2 0.0617030
## 3: 0 13218 3 0.0925545
## 4: 0 10539 4 0.1234060
## 5: 0 8462 5 0.1542575
## ---
## 20046: 1 425 598 596.0132890
## 20047: 1 589 599 597.0099668
## 20048: 1 437 600 598.0066445
## 20049: 1 225 601 599.0033223
## 20050: 1 33 602 600.0000000
The idea is that n.samp can be used to control the number of samples we take from the smallest stratum. If we take the rows with n.samp <= 1, then we get 1 sample from the smallest stratum, along with a proportional number of samples from the other strata. In other words, we can use n.samp to define batch.i, a batch number:
shuffle_dt[
, batch.i := ceiling(n.samp/min_samples_per_stratum)
][]
## Key: <autism>
## autism row.id i.in.stratum n.samp batch.i
## <num> <int> <int> <num> <num>
## 1: 0 17401 1 0.0308515 1
## 2: 0 4775 2 0.0617030 1
## 3: 0 13218 3 0.0925545 1
## 4: 0 10539 4 0.1234060 1
## 5: 0 8462 5 0.1542575 1
## ---
## 20046: 1 425 598 596.0132890 199
## 20047: 1 589 599 597.0099668 200
## 20048: 1 437 600 598.0066445 200
## 20049: 1 225 601 599.0033223 200
## 20050: 1 33 602 600.0000000 200
We see from the output above that batch.i is an integer from 1 to 200, which indicates in which batch each sample appears. Below we compute the count of each class label in each batch.
(class_counts <- dcast(shuffle_dt, batch.i ~ autism, length))
## Key: <batch.i>
## batch.i 0 1
## <num> <int> <int>
## 1: 1 97 3
## 2: 2 97 3
## 3: 3 97 3
## 4: 4 97 3
## 5: 5 98 3
## ---
## 196: 196 98 3
## 197: 197 97 3
## 198: 198 97 3
## 199: 199 97 3
## 200: 200 98 4
The table above has one row per batch, and one column per class label. We see that the class counts are mostly constant across batches, consistent with stratified random sampling.
class_counts[`1` < min_samples_per_stratum]
## Key: <batch.i>
## Empty data.table (0 rows and 3 cols): batch.i,0,1
The output above shows that there are no batches with fewer positive examples than the specified minimum.
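For a programmatic version of this check, one could add an assertion (a one-line sketch using the class_counts table above):

# Error out if any batch has fewer positive examples than the required minimum.
stopifnot(class_counts[, all(`1` >= min_samples_per_stratum)])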
Custom sampler
How can we use the code above with torch? We need to define a sampler class, as in the code below:
hack_sampler_class <- torch::sampler(
  "HackSampler",
  initialize = function(data_source) {
    # Shuffle the data and assign a batch number to each row, exactly as in
    # the code blocks above (stratum and min_samples_per_stratum are globals).
    shuffle_dt <- data.table(
      row.id = 1:length(data_source)
    )[
      , autism := data_source$.getbatch(row.id)
    ][sample(.N)][
      , i.in.stratum := 1:.N, keyby=stratum
    ][]
    count_dt <- shuffle_dt[, .(max.i=max(i.in.stratum)), by=stratum][order(max.i)]
    count_min <- count_dt$max.i[1]
    num_batches <- count_min %/% min_samples_per_stratum
    max_samp <- num_batches * min_samples_per_stratum
    shuffle_dt[
      , n.samp := i.in.stratum/max(i.in.stratum)*max_samp, by=stratum
    ][
      , batch.i := ceiling(n.samp/min_samples_per_stratum)
    ][]
    self$batch_list <- split(shuffle_dt, shuffle_dt$batch.i)
  },
  .iter = function() {
    count <- 0
    function() {
      # Return the row indices of the next batch, or exhausted() when done.
      if (count < length(self$batch_list)) {
        count <<- count + 1L
        return(self$batch_list[[count]]$row.id)
      }
      coro::exhausted()
    }
  },
  .length = function() {
    length(self$data_source)
  }
)
I call the code above a “hack” because it takes the same fixed value for min_samples_per_stratum as defined in a previous code block (TODO: make this a parameter).
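One way to address that TODO (an untested sketch, not code that was run for this post) is to wrap the sampler definition in a factory function, so that min_samples_per_stratum is captured in the factory's environment rather than read from the global workspace:

make_hack_sampler_class <- function(min_samples_per_stratum, stratum="autism") {
  # Returns a sampler generator whose methods see min_samples_per_stratum and
  # stratum through their enclosing environment, instead of global variables.
  torch::sampler(
    "HackSampler",
    initialize = function(data_source) {
      shuffle_dt <- data.table(
        row.id = 1:length(data_source)
      )[
        # the column name "autism" is still hard-coded here, as in the original.
        , autism := data_source$.getbatch(row.id)
      ][sample(.N)][
        , i.in.stratum := 1:.N, keyby=stratum
      ][]
      count_dt <- shuffle_dt[, .(max.i=max(i.in.stratum)), by=stratum][order(max.i)]
      num_batches <- count_dt$max.i[1] %/% min_samples_per_stratum
      max_samp <- num_batches * min_samples_per_stratum
      shuffle_dt[
        , n.samp := i.in.stratum/max(i.in.stratum)*max_samp, by=stratum
      ][
        , batch.i := ceiling(n.samp/min_samples_per_stratum)
      ][]
      self$batch_list <- split(shuffle_dt, shuffle_dt$batch.i)
    },
    .iter = function() {
      count <- 0
      function() {
        if (count < length(self$batch_list)) {
          count <<- count + 1L
          return(self$batch_list[[count]]$row.id)
        }
        coro::exhausted()
      }
    },
    # Number of batches per epoch, consistent with what .iter produces.
    .length = function() length(self$batch_list)
  )
}
# Hypothetical usage: the minimum per stratum becomes an explicit argument.
# hack_dl <- torch::dataloader(ds, batch_sampler = make_hack_sampler_class(3)(ds))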
To use the sampler class, we must first instantiate it with a data set:
hack_sampler_instance <- hack_sampler_class(ds)
Then we specify that instance as the sampler argument of the dataloader:
hack_dl <- torch::dataloader(ds, batch_sampler = hack_sampler_instance)
Finally we can loop over batches, to verify that the stratified sampling works.
torch::torch_manual_seed(1)
count_dt_list <- list()
coro::loop(for (batch_tensor in hack_dl) {
  batch_id <- length(count_dt_list) + 1L
  batch_dt <- data.table(
    batch_id,
    class=torch::as_array(batch_tensor))
  count_dt_list[[batch_id]] <- dcast(batch_dt, batch_id ~ class, length)
})
(count_dt <- rbindlist(count_dt_list))
## batch_id 0 1
## <int> <int> <int>
## 1: 1 97 3
## 2: 2 97 3
## 3: 3 97 3
## 4: 4 97 3
## 5: 5 98 3
## ---
## 196: 196 98 3
## 197: 197 97 3
## 198: 198 97 3
## 199: 199 97 3
## 200: 200 98 4
Plugging into mlr3torch
First, note that this work resulted in a doc improvement to torch and several modifications to mlr3torch:
- The learner batch_sampler and sampler params are now set to the class (not instance), PR. The sampler is instantiated at the same time as the learner.
- An informative error was added when the sampler .length is not consistent with the number of times .iter can be called before returning coro::exhausted(), PR.
- In the code below we either need to specify batch_size (even though it is unnecessary), or use the fix in this other PR, which removes the error for no batch_size when batch_sampler is specified.
We can then create a simple linear model torch learner in the mlr3torch system, and apply it to the sonar data set, using the stratified sampling strategy.
First we create a sonar task, and set its stratum column role, which will be used for stratification in our custom sampler.
sonar_task <- mlr3::tsk("sonar")
sonar_task$col_roles$stratum <- "Class"
Then we create a new MLP learner, which by default is a linear model.
mlp_learner <- mlr3torch::LearnerTorchMLP$new(task_type="classif")
mlp_learner$predict_type <- "prob"
Then we set several learner parameters in the code below. Also note in the stratified_sampler_class that
- initialize derives the stratification from the stratum role defined in the task.
- set_batch_list sets self$batch_list, a list with one element per batch, each element being an integer vector of row indices.
- Samples are seen in a random order because of sample(.N), and this order is different in each epoch, because set_batch_list is called to set a new self$batch_list after each epoch is complete.
min_samples_per_stratum <- 10
stratified_sampler_class <- torch::sampler(
  "StratifiedSampler",
  initialize = function(data_source) {
    self$data_source <- data_source
    TSK <- data_source$task
    # Keep the stratum column(s) defined by the task, plus the row number.
    self$stratum <- TSK$col_roles$stratum
    self$stratum_dt <- data.table(
      TSK$data(cols=self$stratum),
      row.id=1:TSK$nrow)
    self$set_batch_list()
  },
  set_batch_list = function() {
    # Shuffle the rows, then assign each row a batch number so that every
    # batch has at least min_samples_per_stratum rows from each stratum.
    shuffle_dt <- self$stratum_dt[sample(.N)][
      , i.in.stratum := 1:.N, by=c(self$stratum)
    ][]
    count_dt <- shuffle_dt[, .(
      max.i=max(i.in.stratum)
    ), by=c(self$stratum)][order(max.i)]
    count_min <- count_dt$max.i[1]
    num_batches <- count_min %/% min_samples_per_stratum
    max_samp <- num_batches * min_samples_per_stratum
    shuffle_dt[
      , n.samp := i.in.stratum/max(i.in.stratum)*max_samp
      , by=c(self$stratum)
    ][
      , batch.i := ceiling(n.samp/min_samples_per_stratum)
    ][]
    # Print per-batch class counts and indices (used for verification below).
    print(dcast(
      shuffle_dt,
      batch.i ~ Class,
      list(length, indices=function(x)paste(x, collapse=",")),
      value.var="row.id"))
    self$batch_list <- split(shuffle_dt$row.id, shuffle_dt$batch.i)
    self$batch_sizes <- sapply(self$batch_list, length)
    self$batch_size_tab <- sort(table(self$batch_sizes))
    self$batch_size <- as.integer(names(self$batch_size_tab)[length(self$batch_size_tab)])
  },
  .iter = function() {
    count <- 0
    function() {
      if (count < length(self$batch_list)) {
        count <<- count + 1L
        indices <- self$batch_list[[count]]
        if (count == length(self$batch_list)) {
          # After the last batch of this epoch, re-shuffle for the next epoch.
          self$set_batch_list()
        }
        return(indices)
      }
      coro::exhausted()
    }
  },
  .length = function() {
    length(self$batch_list)
  }
)
mlp_learner$param_set$set_values(
  epochs=1,
  p=0, # dropout probability.
  batch_size=1, # ignored.
  batch_sampler=stratified_sampler_class)
In the code above we set parameters:
- epochs=1 for one epoch of learning.
- p=0 for no dropout regularization.
- batch_size=1 to avoid the error that this parameter is required; it is actually ignored because we also specify a batch_sampler. This is a bug which should be fixed by this PR.
In the code below we train:
mlp_learner$train(sonar_task)
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R row.id_indices_M row.id_indices_R
## <num> <int> <int> <char> <char>
## 1: 1 12 10 159,165,195,163,157,158,193,186,202,114,204,191 2,89,65,94,17,83,57,39,84,27
## 2: 2 12 11 196,177,206,140,98,153,201,149,151,127,203,126 30,66,26,72,5,59,51,71,96,52,56
## 3: 3 13 11 152,155,190,194,120,106,146,160,172,181,130,171,154 54,60,24,46,6,55,18,1,79,40,58
## 4: 4 12 11 115,131,145,188,164,109,185,100,192,118,107,147 53,68,95,8,22,73,61,62,77,97,31
## 5: 5 12 10 112,125,166,133,123,143,141,180,116,189,167,208 78,43,93,11,85,23,9,33,76,70
## 6: 6 13 11 135,110,178,199,117,128,156,150,101,138,108,168,132 36,80,91,74,28,88,50,41,12,7,35
## 7: 7 12 11 103,179,137,162,182,173,119,111,136,148,174,139 47,16,32,21,75,20,44,13,82,38,69
## 8: 8 12 11 104,105,124,175,99,169,200,207,184,122,113,187 87,10,15,49,92,67,37,3,19,63,42
## 9: 9 13 11 176,161,170,198,183,197,144,129,102,121,134,142,205 90,14,81,25,86,48,45,34,29,4,64
## Key: <batch.i>
## batch.i row.id_length_M row.id_length_R row.id_indices_M row.id_indices_R
## <num> <int> <int> <char> <char>
## 1: 1 12 10 174,118,106,172,196,125,178,103,107,132,199,207 70,16,44,86,40,75,47,6,61,78
## 2: 2 12 11 123,206,176,113,112,138,147,135,116,192,119,194 27,21,7,29,51,79,31,92,38,12,89
## 3: 3 13 11 115,162,155,167,146,166,98,202,122,161,195,149,148 90,1,53,30,58,67,52,69,57,45,84
## 4: 4 12 11 193,184,104,164,114,190,127,203,133,173,189,142 62,88,25,39,49,10,8,19,56,76,83
## 5: 5 12 10 139,121,201,102,158,204,154,171,170,180,186,137 81,94,9,37,71,68,23,97,73,93
## 6: 6 13 11 208,181,187,131,182,136,145,101,124,188,144,160,128 43,28,60,32,96,18,65,80,2,85,91
## 7: 7 12 11 117,99,120,156,126,183,111,197,134,100,191,105 55,41,11,34,26,36,50,77,20,87,82
## 8: 8 12 11 157,143,108,205,110,179,150,177,168,151,169,141 95,5,63,35,54,4,64,14,15,59,33
## 9: 9 13 11 153,140,109,200,175,198,130,185,159,129,163,152,165 24,13,48,3,46,22,72,74,17,42,66
The output above is from the print statement inside set_batch_list, which shows that
- there are two tables printed, one for the first epoch, and one for the second (not used).
- each row represents a batch.
- in each table, the row.id_length_* columns show the number of positive and negative labels in a batch.
- the number of minority class samples (R) is always at least 10.
- the first batch in the first table has the same label counts as the first batch in the second table, etc.
- the row.id_indices_M for the first batch in the first table are different from the corresponding indices in the second table.
- so each epoch uses the samples in a different order, but with the same label counts in each batch.
Conclusions
We have explained how to create a custom stratified sampler for use in the mlr3torch framework. This will be useful in experiments with loss functions that require a minimum number of samples of each class to get a non-zero gradient.
Read the next post to see how we extended this method, exploring different min_samples_per_stratum values, and verifying its correctness.
Session info
sessionInfo()
## R version 4.5.1 (2025-06-13)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.3 LTS
##
## Matrix products: default
## BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.12.0
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.12.0 LAPACK version 3.12.0
##
## locale:
## [1] LC_CTYPE=fr_FR.UTF-8 LC_NUMERIC=C LC_TIME=fr_FR.UTF-8 LC_COLLATE=fr_FR.UTF-8 LC_MONETARY=fr_FR.UTF-8
## [6] LC_MESSAGES=fr_FR.UTF-8 LC_PAPER=fr_FR.UTF-8 LC_NAME=C LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=fr_FR.UTF-8 LC_IDENTIFICATION=C
##
## time zone: America/Toronto
## tzcode source: system (glibc)
##
## attached base packages:
## [1] stats graphics grDevices datasets utils methods base
##
## other attached packages:
## [1] data.table_1.17.8
##
## loaded via a namespace (and not attached):
## [1] crayon_1.5.3 knitr_1.50 cli_3.6.5 xfun_0.53 rlang_1.1.6
## [6] processx_3.8.6 torch_0.16.0 coro_1.1.0 bit_4.6.0 mlr3pipelines_0.9.0
## [11] listenv_0.9.1 backports_1.5.0 pkgbuild_1.4.8 ps_1.9.1 paradox_1.0.1
## [16] mlr3misc_0.18.0 evaluate_1.0.5 mlr3_1.1.0 palmerpenguins_0.1.1 mlr3torch_0.3.1
## [21] compiler_4.5.1 mlr3resampling_2025.7.30 codetools_0.2-20 Rcpp_1.1.0 future_1.67.0
## [26] digest_0.6.37 R6_2.6.1 curl_7.0.0 parallelly_1.45.1 parallel_4.5.1
## [31] magrittr_2.0.3 callr_3.7.6 checkmate_2.3.3 withr_3.0.2 bit64_4.6.0-1
## [36] uuid_1.2-1 tools_4.5.1 globals_0.18.0 bspm_0.5.7 lgr_0.5.0
## [41] remotes_2.5.0 desc_1.4.3