
ELBO Diverged During Optimization with PLNnetwork and Torch Backend

EngineerDanny opened this issue 6 months ago · 8 comments

Problem Description

When using PLNnetwork() with the torch backend on the barents dataset, I encounter an ELBO divergence error during the optimization procedure.

Error Message

Error in (function (data, params, config)  : 
  The ELBO diverged during the optimization procedure.
Consider using:
* a different optimizer (current optimizer: RPROP)
* a smaller learning rate (current rate: 0.100)
with `control = PLN_param(backend = 'torch', config_optim = list(algorithm = ..., lr = ...))`
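For reference, the remedy the error message suggests maps onto `PLNnetwork_param()` as follows. This is a sketch only: the ADAM/0.01 values are illustrative, not a verified fix for this dataset.

```r
# Illustrative only: switch the torch optimizer and lower the learning
# rate via config_optim, as the error message suggests.
library(PLNmodels)

control <- PLNnetwork_param(
  backend      = "torch",
  config_optim = list(
    algorithm = "ADAM",  # instead of the reported RPROP
    lr        = 0.01     # instead of the reported 0.100
  )
)
```

The resulting `control` object is then passed to `PLNnetwork()` via its `control` argument.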

Minimal Reproducible Example

library(PLNmodels)
library(data.table)
library(glmnet)
set.seed(1)
data("barents", package = "PLNmodels")
task.dt <- as.data.table(t(barents$Abundance))
Y <- as.matrix(task.dt)
sample_names <- paste0("s", seq_len(nrow(Y)))
rownames(Y) <- sample_names
prepared_data <- prepare_data(counts = Y, 
                              covariates = data.frame(row.names = sample_names))
input_cols <- colnames(Y)[1:(ncol(Y)-1)]
output_col <- colnames(Y)[ncol(Y)]
fit <- PLNnetwork(
  Abundance ~ 1,
  data    = prepared_data,
  penalties = 0.1, 
  control = PLNnetwork_param(
    backend = "torch"
  ) 
)

Environment Information

  • R version: 4.5.0 (2025-04-11)
  • Platform: aarch64-apple-darwin20
  • Operating System: macOS Sequoia 15.5
  • PLNmodels version: 1.2.2
  • torch version: 0.14.2

EngineerDanny avatar Jul 21 '25 15:07 EngineerDanny

A different error occurs when I log1p-transform the abundance data before fitting. Ideally, both cases (raw counts and log-transformed data) should work.

library(PLNmodels)
library(data.table)
library(glmnet)
set.seed(1)
data("barents", package = "PLNmodels")
task.dt <- as.data.table(t(barents$Abundance))
task.dt[, names(task.dt) := lapply(.SD, log1p)]
Y <- as.matrix(task.dt)
sample_names <- paste0("s", seq_len(nrow(Y)))
rownames(Y) <- sample_names
prepared_data <- prepare_data(counts = Y, 
                              covariates = data.frame(row.names = sample_names))
input_cols <- colnames(Y)[1:(ncol(Y)-1)]
output_col <- colnames(Y)[ncol(Y)]
fit <- PLNnetwork(
  Abundance ~ 1,
  data    = prepared_data,
  penalties = 0.1, 
  control = PLNnetwork_param(
    backend = "torch"
  ) 
)

Output

Initialization...
 Adjusting 1 PLN with sparse inverse covariance estimation
	Joint optimization alternating gradient descent and graphical-lasso
Error in torch_tensor_cpp(data, dtype, device, requires_grad, pin_memory) : 
  R type not handled

EngineerDanny avatar Jul 21 '25 15:07 EngineerDanny

Dear user

Thank you for your interest in PLNmodels. I would like to point out that PLNmodels is built for count data: it relies heavily on the Poisson distribution and takes great care to preserve the count nature of the data. You can provide a real-valued matrix to the PLN*() functions, since we don't check for it, but that is not the intended use.

If you want to build a network after log-transforming the data, you can use one of the many methods that clr- or log-transform abundance data (SPRING, Magma, SPIEC-EASI, etc.), several of which are available in NetCoMi.
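For illustration, a centered log-ratio (clr) transform can be sketched in a few lines of base R, using a pseudocount to handle zeros. This is a toy sketch (the `clr` helper below is hypothetical), not the implementation used by SPRING, SPIEC-EASI, or NetCoMi, which handle zeros and normalisation more carefully.

```r
# Toy clr (centered log-ratio) sketch for a samples-by-species count
# matrix; dedicated packages treat zeros and normalisation more carefully.
clr <- function(counts, pseudocount = 1) {
  logs <- log(counts + pseudocount)
  sweep(logs, 1, rowMeans(logs))  # centre each sample's log-abundances
}

Y <- matrix(c(10, 0, 5, 2, 8, 1), nrow = 2)  # 2 samples x 3 species
Y_clr <- clr(Y)
rowSums(Y_clr)  # ~0 for every sample, by construction
```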

mahendra-mariadassou avatar Jul 21 '25 15:07 mahendra-mariadassou

@mahendra-mariadassou That is true, but in this code I used only count data -> https://github.com/PLN-team/PLNmodels/issues/155#issue-3249048537

EngineerDanny avatar Jul 21 '25 16:07 EngineerDanny

I was indeed a bit hasty. Let me investigate and get back to you.

mahendra-mariadassou avatar Jul 21 '25 20:07 mahendra-mariadassou

OK, I fixed the type problem, but I still experience convergence problems, so I am reopening:

  • lowering the learning rate for Rprop helps
  • but the overall optimization fails

So I need to further investigate.

jchiquet avatar Jul 24 '25 15:07 jchiquet

This might be a combination of RPROP not being appropriate in this case and the learning rate being too high. It seems to work with ADAM and a small learning rate (using the dev version of the package).

library(PLNmodels)
#> This is package 'PLNmodels' version 1.2.2-9100
#> Use future::plan(multicore/multisession) to speed up PLNPCA/PLNmixture/stability_selection.
set.seed(1)
data("barents", package = "PLNmodels")
fit <- PLNnetwork(
  Abundance ~ 1,
  data    = barents,
  penalties = 0.1, 
  control = PLNnetwork_param(
    backend = "torch", 
    config_optim = list(
      algorithm = "ADAM", 
      lr = 0.01)
  ) 
)
#> 
#>  Initialization...
#>  Adjusting 1 PLN with sparse inverse covariance estimation
#>  Joint optimization alternating gradient descent and graphical-lasso
#>  sparsifying penalty = 0.1 
#>  Post-treatments
#>  DONE!
fit$convergence
#>   param nb_param objective iterations status backend  convergence
#> 1   0.1      156   5068.79         20      3   torch 0.0001701754

Created on 2025-07-24 with reprex v2.1.1

mahendra-mariadassou avatar Jul 24 '25 15:07 mahendra-mariadassou

Everything is working on my side.

See below for the CUDA and cuDNN versions I am using, along with other session information.

> library(PLNmodels)
> packageVersion("PLNmodels")
[1] ‘1.2.2.9100’
> library(PLNmodels)
> library(data.table)
data.table 1.17.6 using 10 threads (see ?getDTthreads).  Latest news: r-datatable.com
> library(glmnet)
Loading required package: Matrix
Loaded glmnet 4.1-9
> packageVersion("PLNmodels")
[1] ‘1.2.2.9100’
> set.seed(1)
> data("barents", package = "PLNmodels")
> task.dt <- as.data.table(t(barents$Abundance))
> task.dt[, names(task.dt) := lapply(.SD, log1p)]
> Y <- as.matrix(task.dt)
> sample_names <- paste0("s", seq_len(nrow(Y)))
> rownames(Y) <- sample_names
> prepared_data <- prepare_data(counts = Y, 
+                               covariates = data.frame(row.names = sample_names))
> input_cols <- colnames(Y)[1:(ncol(Y)-1)]
> output_col <- colnames(Y)[ncol(Y)]
> 
> fit_nlopt <- PLNnetwork(
+   Abundance ~ 1,
+   data    = prepared_data,
+   penalties = 0.1, 
+   control = PLNnetwork_param(
+     backend = "nlopt"
+   ) 
+ )

 Initialization...
 Adjusting 1 PLN with sparse inverse covariance estimation
	Joint optimization alternating gradient descent and graphical-lasso
	sparsifying penalty = 0.1 
 Post-treatments
 DONE!
> fit_torch <- PLNnetwork(
+     Abundance ~ 1,
+     data    = prepared_data,
+     penalties = 0.1, 
+     control = PLNnetwork_param(
+         backend = "torch",
+         config_optim = list(
+             algorithm = "ADAM", 
+             lr = 0.01)
+     ) 
+ )

 Initialization...
 Adjusting 1 PLN with sparse inverse covariance estimation
	Joint optimization alternating gradient descent and graphical-lasso
	sparsifying penalty = 0.1 
 Post-treatments
 DONE!
> 
> fit_torch$convergence
  param nb_param objective iterations status backend  convergence
1   0.1     1492  2628.201         20      3   torch 3.648069e-05
> fit_nlopt$convergence
  param nb_param status backend objective iterations  convergence
1   0.1     1544      3   nlopt  2511.806         20 0.0001000882
> plot(fit_nlopt, "diagnostic")
> plot(fit_torch, "diagnostic")
> fit_torch_rprop <- PLNnetwork(
+     Abundance ~ 1,
+     data    = prepared_data,
+     penalties = 0.1, 
+     control = PLNnetwork_param(
+         backend = "torch",
+         config_optim = list(
+             algorithm = "RPROP", 
+             lr = 0.01)
+     ) 
+ )

 Initialization...
 Adjusting 1 PLN with sparse inverse covariance estimation
	Joint optimization alternating gradient descent and graphical-lasso
	sparsifying penalty = 0.1 
 Post-treatments
 DONE!
> 
> torch::cuda_current_device()
[1] 0
> torch::cuda_is_available()
[1] TRUE
> torch::backends_cudnn_version()
[1] ‘90.1.0’
> torch::cuda_runtime_version()
[1] ‘12.4.0’
> fit_nlopt$convergence
  param nb_param status backend objective iterations  convergence
1   0.1     1544      3   nlopt  2511.806         20 0.0001000882
> fit_torch$convergence
  param nb_param objective iterations status backend  convergence
1   0.1     1492  2628.201         20      3   torch 3.648069e-05
> fit_torch_rprop$convergence
  param nb_param objective iterations status backend  convergence
1   0.1     1011  2627.833         20      3   torch 0.000332811
> sessionInfo()
R version 4.5.1 (2025-06-13)
Platform: x86_64-pc-linux-gnu
Running under: Ubuntu 24.04.2 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so;  LAPACK version 3.12.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=fr_FR.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=fr_FR.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=fr_FR.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=fr_FR.UTF-8 LC_IDENTIFICATION=C       

time zone: Europe/Paris
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] glmnet_4.1-9         Matrix_1.7-3         data.table_1.17.6   
[4] PLNmodels_1.2.2-9100

loaded via a namespace (and not attached):
 [1] future_1.58.0       generics_0.1.4      tidyr_1.3.1        
 [4] shape_1.4.6.1       lattice_0.22-7      listenv_0.9.1      
 [7] digest_0.6.37       magrittr_2.0.3      grid_4.5.1         
[10] RColorBrewer_1.1-3  iterators_1.0.14    foreach_1.5.2      
[13] processx_3.8.6      survival_3.8-3      torch_0.15.0       
[16] ps_1.9.1            gridExtra_2.3       purrr_1.0.4        
[19] glassoFast_1.0.1    scales_1.4.0        coro_1.1.0         
[22] codetools_0.2-20    cli_3.6.5           rlang_1.1.6        
[25] parallelly_1.45.0   future.apply_1.20.0 splines_4.5.1      
[28] bit64_4.6.0-1       withr_3.0.2         corrplot_0.95      
[31] tools_4.5.1         parallel_4.5.1      nloptr_2.2.1       
[34] dplyr_1.1.4         ggplot2_3.5.2       globals_0.18.0     
[37] vctrs_0.6.5         R6_2.6.1            lifecycle_1.0.4    
[40] bit_4.6.0           MASS_7.3-65         pkgconfig_2.0.3    
[43] callr_3.7.6         pillar_1.11.0       gtable_0.3.6       
[46] glue_1.8.0          Rcpp_1.1.0          tibble_3.3.0       
[49] tidyselect_1.2.1    rstudioapi_0.17.1   farver_2.1.2       
[52] igraph_2.1.4        labeling_0.4.3      pscl_1.5.9         
[55] compiler_4.5.1     
>

jchiquet avatar Jul 29 '25 14:07 jchiquet

@mahendra-mariadassou, Daniel (a PhD student of Toby Hocking's with whom I am working on PLN) is using macOS: do you use a Mac too, by any chance?

That would explain the common error (the macOS equivalent of the CUDA/torch library probably uses a different implementation of RPROP/ADAM).

thanks.

J.

jchiquet avatar Jul 29 '25 14:07 jchiquet