The goal of this post is to show how to use a custom torch sampler with mlr3torch, in order to implement stratified sampling, which ensures that each batch in gradient descent has a minimum number of samples from each class.

Motivation: imbalanced classification

We consider imbalanced classification problems, which occur frequently in many different areas. For example, in a recent project involving predicting childhood autism, we used data from the National Survey of Children’s Health (NSCH), which had about 20K rows, of which about 3% were labeled as autism.

library(data.table)
## data.table 1.17.8 using 3 threads (see ?getDTthreads).  Latest news: r-datatable.com
prop.pos <- 0.03
Nrow <- 20000
# simulate a label vector with 3% positive (600) and 97% negative (19400).
(aut_sim <- data.table(autism=rep(c(1,0), c(prop.pos, 1-prop.pos)*Nrow)))
##        autism
##         <num>
##     1:      1
##     2:      1
##     3:      1
##     4:      1
##     5:      1
##    ---       
## 19996:      0
## 19997:      0
## 19998:      0
## 19999:      0
## 20000:      0

To learn with these data in torch, we can use stochastic gradient descent. To do that, we need to wrap the data table in a dataset as below:

ds_gen <- torch::dataset(
  initialize=function(){},
  # return the vector of labels for the rows indexed by i.
  .getbatch=function(i)aut_sim[i]$autism,
  # total number of samples in the data set.
  .length=function()nrow(aut_sim))

After that, we need to attach the dataset to a dataloader with a certain batch size, which we define as 100 below:

ds <- ds_gen()
batch_size <- 100
dl <- torch::dataloader(ds, batch_size=batch_size, shuffle=TRUE)

To iterate over the dataloader we use coro::loop, counting the number of samples of each class in each batch:

torch::torch_manual_seed(1)
count_dt_list <- list()
coro::loop(for (batch_tensor in dl) {
  batch_vec <- torch::as_array(batch_tensor)
  batch_id <- length(count_dt_list) + 1L
  count_dt_list[[batch_id]] <- data.table(
    batch_id,
    num_0=sum(batch_vec==0),
    num_1=sum(batch_vec==1))
})
(count_dt <- rbindlist(count_dt_list))
##      batch_id num_0 num_1
##         <int> <int> <int>
##   1:        1   100     0
##   2:        2    96     4
##   3:        3    96     4
##   4:        4    97     3
##   5:        5    98     2
##  ---                     
## 196:      196    98     2
## 197:      197    96     4
## 198:      198    98     2
## 199:      199    96     4
## 200:      200    96     4
count_dt[num_1==0]
##     batch_id num_0 num_1
##        <int> <int> <int>
##  1:        1   100     0
##  2:        6   100     0
##  3:        9   100     0
##  4:       23   100     0
##  5:       44   100     0
##  6:       67   100     0
##  7:       70   100     0
##  8:       77   100     0
##  9:       93   100     0
## 10:       98   100     0
## 11:      107   100     0

We can see above that there are 200 batches, 11 of which have no positive labels. On average there should be 3 positive labels per batch, which the quantiles below confirm:

quantile(count_dt$num_1)
##   0%  25%  50%  75% 100% 
##    0    2    3    4    9

Above we see that the 50% quantile (median) of the number of positive labels per batch is equal to 3, as expected.
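
In fact the mean is exactly 3: shuffle=TRUE permutes all 20000 rows into 200 batches per epoch, so each of the 600 positive labels appears in exactly one batch, giving 600/200 = 3 positives per batch on average. We can check this directly (a quick check, not in the original output):

# each positive appears in exactly one of the 200 batches, so this is 3.
mean(count_dt$num_1)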

With typical loss functions, such as the cross-entropy (logistic) loss, gradient descent learning still works with a batch of all negative examples. However, that is not the case with complex loss functions like the Area Under the Minimum (AUM) of False Positives and False Negatives, which we recently proposed for ROC curve optimization (Journal of Machine Learning Research 2023). Computing the AUM requires computing a ROC curve, which is not possible without at least one positive and one negative example. So in gradient descent learning, when a batch contains only negative examples, we can not compute the AUM, and we must skip that batch (a waste of time). It would be better if we could use stratified sampling, to control the number of minority class samples per batch.
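
To make the wasted work concrete, below is a minimal sketch of the guard that a plain shuffling dataloader forces on us; model and aum_loss are hypothetical stand-ins, not functions defined in this post:

# sketch: a batch with only one class cannot form a ROC curve,
# so AUM is undefined and the batch must be skipped.
coro::loop(for (batch_tensor in dl) {
  batch_vec <- torch::as_array(batch_tensor)
  if (all(batch_vec == 0) || all(batch_vec == 1)) {
    # wasted iteration: no AUM gradient step is possible.
  } else {
    # pred_tensor <- model(feature_tensor)        # hypothetical
    # loss <- aum_loss(pred_tensor, batch_tensor) # hypothetical
    # loss$backward(); optimizer$step()
  }
})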

Stratified sampling

One way to implement stratified sampling is via the code below. First we define a minimum number of samples per stratum:

(min_samples_per_stratum <- prop.pos*batch_size)
## [1] 3

The output above shows that there will be at least 3 samples from each stratum in each batch. Below we shuffle the data set, and assign each sample a sequential index within its stratum:

stratum <- "autism"
set.seed(1)
(shuffle_dt <- aut_sim[
, row.id := 1:.N
][sample(.N)][
, i.in.stratum := 1:.N, keyby=stratum
][])
## Key: <autism>
##        autism row.id i.in.stratum
##         <num>  <int>        <int>
##     1:      0  17401            1
##     2:      0   4775            2
##     3:      0  13218            3
##     4:      0  10539            4
##     5:      0   8462            5
##    ---                           
## 19996:      1    372          596
## 19997:      1    473          597
## 19998:      1    437          598
## 19999:      1    264          599
## 20000:      1    561          600

The output above shows two new columns:

  • row.id is the row number in the original data table.
  • i.in.stratum is the row number within the shuffled data, counted separately for each stratum (autism); both properties are checked in the sketch below.
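
As a quick sanity check (a sketch added here, not part of the original analysis), we can assert both properties:

# labels looked up via row.id must agree with the shuffled table.
stopifnot(aut_sim$autism[shuffle_dt$row.id] == shuffle_dt$autism)
# within each stratum, i.in.stratum counts up from 1.
stopifnot(shuffle_dt[, identical(i.in.stratum, 1:.N), by=stratum]$V1)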

Next, we count the number of samples per stratum.

(count_dt <- shuffle_dt[, .(max.i=max(i.in.stratum)), by=stratum][order(max.i)])
##    autism max.i
##     <num> <int>
## 1:      1   600
## 2:      0 19400
(count_min <- count_dt$max.i[1])
## [1] 600

Above, we see the smallest stratum has 600 samples. Next, we add a column n.samp with values between 0 and 600:

shuffle_dt[
, n.samp := i.in.stratum/max(i.in.stratum)*count_min, by=stratum
][]
## Key: <autism>
##        autism row.id i.in.stratum       n.samp
##         <num>  <int>        <int>        <num>
##     1:      0  17401            1   0.03092784
##     2:      0   4775            2   0.06185567
##     3:      0  13218            3   0.09278351
##     4:      0  10539            4   0.12371134
##     5:      0   8462            5   0.15463918
##    ---                                        
## 19996:      1    372          596 596.00000000
## 19997:      1    473          597 597.00000000
## 19998:      1    437          598 598.00000000
## 19999:      1    264          599 599.00000000
## 20000:      1    561          600 600.00000000

The idea is that n.samp controls the number of samples taken from the smallest stratum: the rows with n.samp <= 1 form a group with 1 sample from the smallest stratum, together with a proportional number of samples from every other stratum. Dividing n.samp by min_samples_per_stratum and taking the ceiling therefore gives batch.i, a batch number:

shuffle_dt[
, batch.i := ceiling(n.samp/min_samples_per_stratum)
][]
## Key: <autism>
##        autism row.id i.in.stratum       n.samp batch.i
##         <num>  <int>        <int>        <num>   <num>
##     1:      0  17401            1   0.03092784       1
##     2:      0   4775            2   0.06185567       1
##     3:      0  13218            3   0.09278351       1
##     4:      0  10539            4   0.12371134       1
##     5:      0   8462            5   0.15463918       1
##    ---                                                
## 19996:      1    372          596 596.00000000     199
## 19997:      1    473          597 597.00000000     199
## 19998:      1    437          598 598.00000000     200
## 19999:      1    264          599 599.00000000     200
## 20000:      1    561          600 600.00000000     200

We see from the output above that batch.i is an integer from 1 to 200, which indicates the batch in which each sample appears. Below we count samples for each combination of batch and class label.

dcast(shuffle_dt, batch.i ~ autism, length)
## Key: <batch.i>
##      batch.i     0     1
##        <num> <int> <int>
##   1:       1    97     3
##   2:       2    97     3
##   3:       3    97     3
##   4:       4    97     3
##   5:       5    97     3
##  ---                    
## 196:     196    97     3
## 197:     197    97     3
## 198:     198    97     3
## 199:     199    97     3
## 200:     200    97     3

The table above has one row per batch, and one column per class label. We see that the class counts are constant across batches, consistent with stratified random sampling.
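
Equivalently (a quick check using the columns computed above), we can assert that every batch contains at least min_samples_per_stratum samples from each stratum:

# count rows for each combination of batch and stratum,
# then check that no count falls below the minimum.
batch_counts <- shuffle_dt[, .N, by=.(batch.i, autism)]
stopifnot(batch_counts$N >= min_samples_per_stratum)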

Custom sampler

How can we use the code above with mlr3torch? First, we need to define a sampler class, as in the code below:

hack_sampler_class <- torch::sampler(
  "HackSampler",
  initialize = function(data_source) {
    self$data_source <- data_source
  },
  .iter_batch = function(batch_size) {
    # shuffle rows, then number each row within its stratum,
    # as in the previous section.
    shuffle_dt <- aut_sim[
    , row.id := 1:.N
    ][sample(.N)][
    , i.in.stratum := 1:.N, keyby=stratum
    ][]
    # size of the smallest stratum.
    count_dt <- shuffle_dt[, .(max.i=max(i.in.stratum)), by=stratum][order(max.i)]
    count_min <- count_dt$max.i[1]
    # assign each row to a batch number.
    shuffle_dt[
    , n.samp := i.in.stratum/max(i.in.stratum)*count_min, by=stratum
    ][
    , batch.i := ceiling(n.samp/min_samples_per_stratum)
    ][]
    batch_list <- split(shuffle_dt, shuffle_dt$batch.i)
    count <- 0
    # iterator: each call returns the row indices of the next batch,
    # then coro::exhausted() when no batches remain.
    function() {
      if (count < length(batch_list)) {
        count <<- count + 1L
        return(batch_list[[count]]$row.id)
      }
      coro::exhausted()
    }
  },
  .length = function() {
    length(self$data_source)
  }
)

I call the code above a “hack” because it reads aut_sim, stratum, and min_samples_per_stratum from the global environment, using the same fixed values defined in previous code blocks (TODO: make these parameters).
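
For reference, here is a sketch of what a parameterized version could look like, moving those globals into constructor arguments; this is one possible design (untested), not the final mlr3torch integration:

stratified_sampler_class <- torch::sampler(
  "StratifiedSampler",
  initialize = function(data_source, stratum_dt, stratum, min_samples_per_stratum) {
    self$data_source <- data_source
    self$stratum_dt <- stratum_dt # data table with one row per sample.
    self$stratum <- stratum       # name of the column to stratify on.
    self$min_samples_per_stratum <- min_samples_per_stratum
  },
  .iter_batch = function(batch_size) {
    # same logic as HackSampler, but reading fields from self.
    # note: := modifies self$stratum_dt by reference, as in the hack.
    shuffle_dt <- self$stratum_dt[
    , row.id := 1:.N
    ][sample(.N)][
    , i.in.stratum := 1:.N, keyby=self$stratum
    ][]
    count_min <- shuffle_dt[, .(max.i=max(i.in.stratum)), by=self$stratum][, min(max.i)]
    shuffle_dt[
    , batch.i := ceiling(
      i.in.stratum/max(i.in.stratum)*count_min/self$min_samples_per_stratum)
    , by=self$stratum]
    batch_list <- split(shuffle_dt, shuffle_dt$batch.i)
    count <- 0
    function() {
      if (count < length(batch_list)) {
        count <<- count + 1L
        return(batch_list[[count]]$row.id)
      }
      coro::exhausted()
    }
  },
  .length = function() {
    length(self$data_source)
  }
)
# instantiation would then look like:
# stratified_sampler_class(ds, aut_sim, "autism", min_samples_per_stratum=3)

To use hack_sampler_class, we must first instantiate it with a data set: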

hack_sampler_instance <- hack_sampler_class(ds)

Then we specify that instance as the sampler argument of the dataloader:

hack_dl <- torch::dataloader(ds, sampler = hack_sampler_instance)

Finally we can loop over batches, to verify that the stratified sampling works.

torch::torch_manual_seed(1)
count_dt_list <- list()
coro::loop(for (batch_tensor in hack_dl) {
  batch_vec <- torch::as_array(batch_tensor)
  batch_id <- length(count_dt_list) + 1L
  count_dt_list[[batch_id]] <- data.table(
    batch_id,
    num_0=sum(batch_vec==0),
    num_1=sum(batch_vec==1))
})
(count_dt <- rbindlist(count_dt_list))
##      batch_id num_0 num_1
##         <int> <int> <int>
##   1:        1    97     3
##   2:        2    97     3
##   3:        3    97     3
##   4:        4    97     3
##   5:        5    97     3
##  ---                     
## 196:      196    97     3
## 197:      197    97     3
## 198:      198    97     3
## 199:      199    97     3
## 200:      200    97     3

Plugging into mlr3torch

TODO

Conclusions

TODO

Session info

sessionInfo()
## R version 4.5.1 (2025-06-13 ucrt)
## Platform: x86_64-w64-mingw32/x64
## Running under: Windows 11 x64 (build 22631)
## 
## Matrix products: default
##   LAPACK version 3.12.1
## 
## locale:
## [1] LC_COLLATE=English_United States.utf8  LC_CTYPE=English_United States.utf8    LC_MONETARY=English_United States.utf8
## [4] LC_NUMERIC=C                           LC_TIME=English_United States.utf8    
## 
## time zone: America/Toronto
## tzcode source: internal
## 
## attached base packages:
## [1] stats     graphics  utils     datasets  grDevices methods   base     
## 
## other attached packages:
## [1] data.table_1.17.8
## 
## loaded via a namespace (and not attached):
##  [1] coro_1.1.0     R6_2.6.1       xfun_0.52      bit_4.6.0      magrittr_2.0.3 torch_0.15.1   knitr_1.50    
##  [8] bit64_4.6.0-1  ps_1.9.1       cli_3.6.5      processx_3.8.6 callr_3.7.6    compiler_4.5.1 tools_4.5.1   
## [15] evaluate_1.0.4 Rcpp_1.1.0     rlang_1.1.6