The goal of this post is to show how to use our recently proposed AUM loss (useful for unbalanced classification problems), with the mlr3torch package in R. This post explains the code I used to prepare slides for a talk last week for Toulouse R User Group.

Intro/issue

While preparing the talk, I ran into an issue, which can be understood using the simple example code below,

library(mlr3torch)
nn_bce_loss3 = nn_module(c("nn_bce_with_logits_loss3", "nn_loss"),
  initialize = function(weight = NULL, reduction = "mean", pos_weight = NULL) {
    self$loss = nn_bce_with_logits_loss(weight, reduction, pos_weight)
  },
  forward = function(input, target) {
    self$loss(input$reshape(-1), target$to(dtype = torch_float())-1)
  }
)
loss = nn_bce_loss3()
loss(torch_randn(10, 1), torch_randint(0, 1, 10))
task = tsk("sonar")
graph = po("torch_ingress_num") %>>%
  nn("linear", out_features = 1) %>>%
  po("torch_loss", loss = nn_bce_loss3) %>>%
  po("torch_optimizer") %>>%
  po("torch_model_classif",
     epochs = 1,
     batch_size = 32,
     predict_type="prob")
glrn = as_learner(graph)
glrn$train(task)
glrn$predict(task)

The code above has predict_type="prob" and out_features=1 so I got the following error, using what was mlr3torch main branch at the time,

if(FALSE){#broke
  remotes::install_github("mlr-org/mlr3torch@6e99e02908788275622a7b723d211f357081699a")
  glrn$predict(task)
  ## Erreur dans dimnames(x) <- dn : 
  ##   la longueur de 'dimnames' [2] n'est pas égale à l'étendue du tableau
  ## This happened PipeOp torch_model_classif's $predict()
}

The error happens because the torch model outputs only one column, but some later code assumes there are two.

My PR

I hacked a solution that fixes this (see below), and I filed a PR.

if(FALSE){#fix
  remotes::install_github("tdhock/mlr3torch@69d4adda7a71c05403d561bf3bb1ffb279978d0d")
  glrn$predict(task)
  ## <PredictionClassif> for 208 observations:
  ##  row_ids truth response
  ##        1     R        M
  ##        2     R        M
  ##        3     R        M
  ##      ---   ---      ---
  ##      206     M        M
  ##      207     M        M
  ##      208     M        M
}

Newer PR

My PR was not generic enough, so Seb Fischer proposed another PR.

remotes::install_github("mlr-org/mlr3torch@c03d61a18e9785e2dbb5b20e2b6dada74a9b58b8")
## Using github PAT from envvar GITHUB_PAT. Use `gitcreds::gitcreds_set()` and unset GITHUB_PAT in .Renviron (or elsewhere) if you want to use the more secure git credential store instead.
## Skipping install of 'mlr3torch' from a github remote, the SHA1 (c03d61a1) has not changed since last install.
##   Use `force = TRUE` to force installation
stask <- mlr3::tsk("sonar")
po_list <- list(
  mlr3torch::PipeOpTorchIngressNumeric$new(),
  mlr3torch::nn("head"),
  mlr3pipelines::po(
    "torch_loss",
    loss = torch::nn_bce_with_logits_loss),
  mlr3pipelines::po("torch_optimizer"),
  mlr3pipelines::po(
    "torch_model_classif",
    epochs = 1,
    batch_size = 1000,
    predict_type="prob"))
graph <- Reduce(mlr3pipelines::concat_graphs, po_list)
glrn <- mlr3::as_learner(graph)
glrn$train(stask)
glrn$predict(stask)
## <PredictionClassif> for 208 observations:
##  row_ids truth response    prob.M    prob.R
##        1     R        M 0.5362849 0.4637151
##        2     R        M 0.5319837 0.4680163
##        3     R        M 0.5488220 0.4511780
##      ---   ---      ---       ---       ---
##      206     M        M 0.5179592 0.4820408
##      207     M        M 0.5232298 0.4767702
##      208     M        M 0.5309298 0.4690702

It looks like this PR improves the mlr3torch support for binary classification! It is important to note a few things about the implementation.

Binary labels in torch and R

First, at the R level, binary labels are represented as a factor with two levels. In the case of the sonar data, the two levels are R and M:

(Class <- stask$data()$Class)
##   [1] R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R
##  [58] R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R R M M M M M M M M M M M M M M M M M
## [115] M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M
## [172] M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M
## Levels: M R

The mlr3torch package converts this representation into a torch float tensor. We can see that by defining a custom loss function, for example my proposed AUM loss for ROC curve optimization.

Proposed_AUM <- function(pred_tensor, label_2d_tensor){
  label_tensor <- label_2d_tensor$flatten()
  is_positive = label_tensor == 1
  is_negative = label_tensor != 1
  fn_diff = torch::torch_where(is_positive, -1, 0)
  fp_diff = torch::torch_where(is_positive, 0, 1)
  thresh_tensor = -pred_tensor$flatten()
  sorted_indices = torch::torch_argsort(thresh_tensor)
  fp_denom = torch::torch_sum(is_negative) #or 1 for AUM based on count instead of rate
  fn_denom = torch::torch_sum(is_positive) #or 1 for AUM based on count instead of rate
  sorted_fp_cum = fp_diff[sorted_indices]$cumsum(dim=1)/fp_denom
  sorted_fn_cum = -fn_diff[sorted_indices]$flip(1)$cumsum(dim=1)$flip(1)/fn_denom
  sorted_thresh = thresh_tensor[sorted_indices]
  sorted_is_diff = sorted_thresh$diff() != 0
  sorted_fp_end = torch::torch_cat(c(sorted_is_diff, torch::torch_tensor(TRUE)))
  sorted_fn_end = torch::torch_cat(c(torch::torch_tensor(TRUE), sorted_is_diff))
  uniq_thresh = sorted_thresh[sorted_fp_end]
  uniq_fp_after = sorted_fp_cum[sorted_fp_end]
  uniq_fn_before = sorted_fn_cum[sorted_fn_end]
  FPR = torch::torch_cat(c(torch::torch_tensor(0.0), uniq_fp_after))
  FNR = torch::torch_cat(c(uniq_fn_before, torch::torch_tensor(0.0)))
  roc <- list(
    FPR=FPR,
    FNR=FNR,
    TPR=1 - FNR,
    "min(FPR,FNR)"=torch::torch_minimum(FPR, FNR),
    min_constant=torch::torch_cat(c(torch::torch_tensor(-Inf), uniq_thresh)),
    max_constant=torch::torch_cat(c(uniq_thresh, torch::torch_tensor(Inf))))
  min_FPR_FNR = roc[["min(FPR,FNR)"]][2:-2]
  constant_diff = roc$min_constant[2:N]$diff()
  torch::torch_sum(min_FPR_FNR * constant_diff)
}
nn_AUM_loss <- torch::nn_module(
  "nn_AUM_loss",
  inherit = torch::nn_mse_loss,
  initialize = function() {
    super$initialize()
  },
  forward = function(input, target) {
    print(input, n=5)
    print(target, n=5)
    print(table(as.integer(target)))
    Proposed_AUM(input, target)
  }
)
po_list <- list(
  mlr3torch::PipeOpTorchIngressNumeric$new(),
  mlr3torch::nn("head"),
  mlr3pipelines::po(
    "torch_loss",
    loss = nn_AUM_loss),
  mlr3pipelines::po("torch_optimizer"),
  mlr3pipelines::po(
    "torch_model_classif",
    epochs = 1,
    batch_size = 1000,
    predict_type="prob"))
graph <- Reduce(mlr3pipelines::concat_graphs, po_list)
glrn <- mlr3::as_learner(graph)
set.seed(2)#controls order of batches.
glrn$train(stask)
## torch_tensor
## -0.3983
## -0.2437
## -0.2986
## -0.3088
## -0.4530
## ... [the output was truncated (use n=-1 to disable)]
## [ CPUFloatType{208,1} ][ grad_fn = <AddmmBackward0> ]
## torch_tensor
##  0
##  1
##  1
##  0
##  1
## ... [the output was truncated (use n=-1 to disable)]
## [ CPUFloatType{208,1} ]
## 
##   0   1 
##  97 111
table(Class, as.integer(Class))
##      
## Class   1   2
##     M 111   0
##     R   0  97

We see in the table above that we have the following correspondence:

R factor level M R
R integer 1 2
torch float 1 0

So the first factor level in R is considered the positive class in torch, which has the float value 1. The negative class is the second factor level, which gets converted to the float value 0.

Conclusions

The mlr3torch now supports binary classification, with neural networks that output a scalar value (larger for more likely to be positive class).

Exercise for the reader

Now that you know how implement a custom loss function for binary classification, you can implement a benchmark_grid with a list of learners, some of which have Proposed_AUM loss, others with classic torch::nn_bce_with_logits_loss, right?

Session Info

sessionInfo()
## R version 4.5.0 (2025-04-11)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.2 LTS
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.12.0 
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.12.0  LAPACK version 3.12.0
## 
## locale:
##  [1] LC_CTYPE=fr_FR.UTF-8       LC_NUMERIC=C               LC_TIME=fr_FR.UTF-8        LC_COLLATE=fr_FR.UTF-8    
##  [5] LC_MONETARY=fr_FR.UTF-8    LC_MESSAGES=fr_FR.UTF-8    LC_PAPER=fr_FR.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C             LC_MEASUREMENT=fr_FR.UTF-8 LC_IDENTIFICATION=C       
## 
## time zone: Europe/Paris
## tzcode source: system (glibc)
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] data.table_1.17.0
## 
## loaded via a namespace (and not attached):
##  [1] crayon_1.5.3         knitr_1.50           cli_3.6.4            xfun_0.51            rlang_1.1.5         
##  [6] processx_3.8.6       torch_0.14.2         coro_1.1.0           bit_4.6.0            mlr3pipelines_0.7.2 
## [11] listenv_0.9.1        backports_1.5.0      ps_1.9.0             paradox_1.0.1        mlr3misc_0.16.0     
## [16] evaluate_1.0.3       mlr3_0.23.0          palmerpenguins_0.1.1 mlr3torch_0.2.1-9000 compiler_4.5.0      
## [21] codetools_0.2-20     Rcpp_1.0.14          mlr3tuning_1.3.0     bbotk_1.5.0          future_1.34.0       
## [26] digest_0.6.37        R6_2.6.1             curl_6.2.2           parallelly_1.43.0    parallel_4.5.0      
## [31] callr_3.7.6          magrittr_2.0.3       checkmate_2.3.2      uuid_1.2-1           tools_4.5.0         
## [36] withr_3.0.2          bit64_4.6.0-1        globals_0.16.3       lgr_0.4.4            remotes_2.5.0