Feature request for `case-weights`
Would it be possible to add support for case weights in TabNet?
This would help with class imbalance and make it easier to compare (and blend) the results of TabNet and XGBoost.
(I will probably upsample the minority class in the meantime as an alternative approach.)
This would be the desired workflow:
# Reprex: fit a tabnet workflow that carries case weights, mirroring the
# tidymodels case-weights workflow used with other engines such as xgboost.
library(tabnet)
library(tidymodels)
library(modeldata)
set.seed(123)
data("lending_club", package = "modeldata")
# Ratio of "good" rows to "bad" rows in the full data set.
class_ratio <- lending_club |>
summarise(sum(Class == "good") / sum(Class == "bad")) |>
pull()
# Give every "bad" row weight class_ratio and every "good" row weight 1, so the
# total weight of the "bad" class equals the number of "good" rows.
# importance_weights() tags the column as case weights rather than a predictor.
lending_club <- lending_club |>
mutate(
case_wts = if_else(Class == "bad", class_ratio, 1),
case_wts = importance_weights(case_wts)
)
# Stratified train/test split on the outcome.
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)
# Class is the outcome; every column not already marked as outcome, id or
# case weights is assigned the predictor role.
tab_rec <-
train |>
recipe() |>
update_role(Class, new_role = "outcome") |>
update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")
set.seed(1)
tab_mod <- tabnet(epochs = 10) |>
set_engine("torch", device = "cpu") |>
set_mode("classification")
# The workflow forwards case_wts to the model as case weights.
tab_wf <- workflow() |>
add_model(tab_mod) |>
add_recipe(tab_rec) |>
add_case_weights(case_wts)
# With the tabnet version used for this reprex, this call aborts in
# parsnip:::check_case_weights() (see the captured output that follows).
tab_fit <- tab_wf |> fit(train)
#> Error in `check_case_weights()`:
#> ! Case weights are not enabled by the underlying model implementation.
#> Backtrace:
#> ▆
#> 1. ├─generics::fit(tab_wf, train)
#> 2. └─workflows:::fit.workflow(tab_wf, train)
#> 3. └─workflows::.fit_model(workflow, control)
#> 4. ├─generics::fit(action_model, workflow = workflow, control = control)
#> 5. └─workflows:::fit.action_model(...)
#> 6. └─workflows:::fit_from_xy(spec, mold, case_weights, control_parsnip)
#> 7. ├─generics::fit_xy(...)
#> 8. └─parsnip::fit_xy.model_spec(...)
#> 9. └─parsnip:::check_case_weights(case_weights, object)
#> 10. └─rlang::abort("Case weights are not enabled by the underlying model implementation.")
Created on 2024-01-12 with reprex v2.0.2
Hello @cgoo4, I have finally implemented it. Would you like to test it and report whether it fits your needs? One way to install it is:
pak::pkg_install("mlverse/tabnet@feature/case_weight")
Hi @cregouby - Thank you.
Ahead of trying it on my own data, I've made a quick test using the toy lending_club data. Untuned TabNet and XGBoost models, with and without case weights, show comparable results!
# Compare untuned TabNet and XGBoost models fitted with case weights, scoring
# each on the test set with and without case-weighted precision-recall curves.
library(tabnet)
library(tidymodels)
library(modeldata)
library(patchwork)

data("lending_club", package = "modeldata")

# yardstick treats the FIRST factor level as the event by default. The PR
# curves below score the "bad" class, so make that assumption explicit.
stopifnot(identical(levels(lending_club$Class)[1], "bad"))

# Give every "bad" row weight n_good / n_bad and every "good" row weight 1, so
# both classes carry the same total weight. importance_weights() tags the
# column as case weights rather than a predictor.
class_ratio <- lending_club |>
  summarise(sum(Class == "good") / sum(Class == "bad")) |>
  pull()

lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)

# Class is the outcome; every column not already marked as outcome, id or
# case weights becomes a predictor.
tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(
    -has_role(c("outcome", "id", "case_weights")),
    new_role = "predictor"
  )

# xgboost needs numeric inputs, so dummy-encode the factor predictors.
xgb_rec <- tab_rec |>
  step_dummy(term, sub_grade, addr_state, verification_status, emp_length)

tab_mod <- tabnet(epochs = 100) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

xgb_mod <- boost_tree(trees = 100) |>
  set_engine("xgboost") |>
  set_mode("classification")

# Both workflows forward case_wts to the model as case weights.
tab_wf <- workflow() |>
  add_model(tab_mod) |>
  add_recipe(tab_rec) |>
  add_case_weights(case_wts)

xgb_wf <- workflow() |>
  add_model(xgb_mod) |>
  add_recipe(xgb_rec) |>
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
xgb_fit <- xgb_wf |> fit(train)

tab_test <- tab_fit |> augment(test)
xgb_test <- xgb_fit |> augment(test)

# Build a precision-recall curve plot for the "bad" (event) class.
# The estimate must be the event's probability column (.pred_bad). Passing
# .pred_good here would score the "good" probabilities against the "bad"
# event and produce a meaningless curve.
# `weighted` only controls whether the METRIC uses case weights; the model in
# `preds` was fitted with case weights either way.
plot_pr <- function(preds, title, weighted) {
  curve <- if (weighted) {
    pr_curve(preds, Class, .pred_bad, case_weights = case_wts)
  } else {
    pr_curve(preds, Class, .pred_bad)
  }
  autoplot(curve) +
    ggtitle(title) +
    theme(plot.title = element_text(size = 9))
}

p1 <- plot_pr(tab_test, "TabNet with Case Weights", weighted = TRUE)
p2 <- plot_pr(tab_test, "TabNet WITHOUT", weighted = FALSE)
p3 <- plot_pr(xgb_test, "XGBoost with Case Weights", weighted = TRUE)
p4 <- plot_pr(xgb_test, "XGBoost WITHOUT", weighted = FALSE)

p1 + p2 + p3 + p4

Created on 2024-02-18 with reprex v2.1.0
In 0.6.0.9000 I'm getting the message "Configured `weights` will not be used":
(It's the same example as above, where the weights are passed along from the workflow.)
# Reprex: the same case-weighted tabnet workflow as above, run against
# tabnet 0.6.0.9000, which reports that the configured weights are not used.
library(tabnet)
library(tidymodels)

data("lending_club", package = "modeldata")

# Number of "good" rows per "bad" row; used as the weight of each "bad" row
# so that both classes carry the same total weight.
class_ratio <- with(lending_club, sum(Class == "good") / sum(Class == "bad"))

lending_club <- lending_club |>
  mutate(case_wts = importance_weights(if_else(Class == "bad", class_ratio, 1)))

set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)

# Outcome = Class; predictors = every column not already marked as outcome,
# id or case weights.
tab_rec <- recipe(train) |>
  update_role(Class, new_role = "outcome") |>
  update_role(
    -has_role(c("outcome", "id", "case_weights")),
    new_role = "predictor"
  )

tab_mod <- tabnet(epochs = 10) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

tab_wf <- workflow(preprocessor = tab_rec, spec = tab_mod) |>
  add_case_weights(case_wts)

tab_fit <- fit(tab_wf, train)
#> Configured `weights` will not be used
tab_test <- augment(tab_fit, test)
Created on 2024-08-09 with reprex v2.1.1
Session info
# Print the R version, platform and package versions used to produce this reprex.
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.4.1 (2024-06-14)
#> os macOS Sonoma 14.5
#> system aarch64, darwin20
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Europe/London
#> date 2024-08-09
#> pandoc 3.2.1 @ /opt/homebrew/bin/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> backports 1.5.0 2024-05-23 [1] CRAN (R 4.4.0)
#> bit 4.0.5 2022-11-15 [1] CRAN (R 4.4.0)
#> bit64 4.0.5 2020-08-30 [1] CRAN (R 4.4.0)
#> broom * 1.0.6 2024-05-17 [1] CRAN (R 4.4.0)
#> callr 3.7.6 2024-03-25 [1] CRAN (R 4.4.0)
#> class 7.3-22 2023-05-03 [2] CRAN (R 4.4.1)
#> cli 3.6.3 2024-06-21 [1] CRAN (R 4.4.0)
#> codetools 0.2-20 2024-03-31 [2] CRAN (R 4.4.1)
#> colorspace 2.1-1 2024-07-26 [1] CRAN (R 4.4.0)
#> coro 1.0.4 2024-03-11 [1] CRAN (R 4.4.0)
#> data.table 1.15.4 2024-03-30 [1] CRAN (R 4.4.0)
#> dials * 1.3.0 2024-07-30 [1] CRAN (R 4.4.0)
#> DiceDesign 1.10 2023-12-07 [1] CRAN (R 4.4.0)
#> digest 0.6.36 2024-06-23 [1] CRAN (R 4.4.0)
#> dplyr * 1.1.4 2023-11-17 [1] CRAN (R 4.4.0)
#> evaluate 0.24.0 2024-06-10 [1] CRAN (R 4.4.0)
#> fansi 1.0.6 2023-12-08 [1] CRAN (R 4.4.0)
#> fastmap 1.2.0 2024-05-15 [1] CRAN (R 4.4.0)
#> foreach 1.5.2 2022-02-02 [1] CRAN (R 4.4.0)
#> fs 1.6.4 2024-04-25 [1] CRAN (R 4.4.0)
#> furrr 0.3.1 2022-08-15 [1] CRAN (R 4.4.0)
#> future 1.34.0 2024-07-29 [1] CRAN (R 4.4.0)
#> future.apply 1.11.2 2024-03-28 [1] CRAN (R 4.4.0)
#> generics 0.1.3 2022-07-05 [1] CRAN (R 4.4.0)
#> ggplot2 * 3.5.1 2024-04-23 [1] CRAN (R 4.4.0)
#> globals 0.16.3 2024-03-08 [1] CRAN (R 4.4.0)
#> glue 1.7.0 2024-01-09 [1] CRAN (R 4.4.0)
#> gower 1.0.1 2022-12-22 [1] CRAN (R 4.4.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.4.0)
#> gtable 0.3.5 2024-04-22 [1] CRAN (R 4.4.0)
#> hardhat 1.4.0 2024-06-02 [1] CRAN (R 4.4.0)
#> htmltools 0.5.8.1 2024-04-04 [1] CRAN (R 4.4.0)
#> infer * 1.0.7 2024-03-25 [1] CRAN (R 4.4.0)
#> ipred 0.9-15 2024-07-18 [1] CRAN (R 4.4.0)
#> iterators 1.0.14 2022-02-05 [1] CRAN (R 4.4.0)
#> jsonlite 1.8.8 2023-12-04 [1] CRAN (R 4.4.0)
#> knitr 1.48 2024-07-07 [1] CRAN (R 4.4.0)
#> lattice 0.22-6 2024-03-20 [2] CRAN (R 4.4.1)
#> lava 1.8.0 2024-03-05 [1] CRAN (R 4.4.0)
#> lhs 1.2.0 2024-06-30 [1] CRAN (R 4.4.0)
#> lifecycle 1.0.4 2023-11-07 [1] CRAN (R 4.4.0)
#> listenv 0.9.1 2024-01-29 [1] CRAN (R 4.4.0)
#> lubridate 1.9.3 2023-09-27 [1] CRAN (R 4.4.0)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.4.0)
#> MASS 7.3-60.2 2024-04-26 [2] CRAN (R 4.4.1)
#> Matrix 1.7-0 2024-04-26 [2] CRAN (R 4.4.1)
#> modeldata * 1.4.0 2024-06-19 [1] CRAN (R 4.4.0)
#> munsell 0.5.1 2024-04-01 [1] CRAN (R 4.4.0)
#> nnet 7.3-19 2023-05-03 [2] CRAN (R 4.4.1)
#> parallelly 1.38.0 2024-07-27 [1] CRAN (R 4.4.0)
#> parsnip * 1.2.1 2024-03-22 [1] CRAN (R 4.4.0)
#> pillar 1.9.0 2023-03-22 [1] CRAN (R 4.4.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.4.0)
#> processx 3.8.4 2024-03-16 [1] CRAN (R 4.4.0)
#> prodlim 2024.06.25 2024-06-24 [1] CRAN (R 4.4.0)
#> ps 1.7.7 2024-07-02 [1] CRAN (R 4.4.0)
#> purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.4.0)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.4.0)
#> Rcpp 1.0.13 2024-07-17 [1] CRAN (R 4.4.0)
#> recipes * 1.1.0 2024-07-04 [1] CRAN (R 4.4.0)
#> reprex 2.1.1 2024-07-06 [1] CRAN (R 4.4.0)
#> rlang 1.1.4 2024-06-04 [1] CRAN (R 4.4.0)
#> rmarkdown 2.27 2024-05-17 [1] CRAN (R 4.4.0)
#> rpart 4.1.23 2023-12-05 [2] CRAN (R 4.4.1)
#> rsample * 1.2.1 2024-03-25 [1] CRAN (R 4.4.0)
#> rstudioapi 0.16.0 2024-03-24 [1] CRAN (R 4.4.0)
#> safetensors 0.1.2 2023-09-12 [1] CRAN (R 4.4.0)
#> scales * 1.3.0 2023-11-28 [1] CRAN (R 4.4.0)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.4.0)
#> survival 3.6-4 2024-04-24 [2] CRAN (R 4.4.1)
#> tabnet * 0.6.0.9000 2024-08-09 [1] Github (mlverse/tabnet@c8c82d2)
#> tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.4.0)
#> tidymodels * 1.2.0 2024-03-25 [1] CRAN (R 4.4.0)
#> tidyr * 1.3.1 2024-01-24 [1] CRAN (R 4.4.0)
#> tidyselect 1.2.1 2024-03-11 [1] CRAN (R 4.4.0)
#> timechange 0.3.0 2024-01-18 [1] CRAN (R 4.4.0)
#> timeDate 4032.109 2023-12-14 [1] CRAN (R 4.4.0)
#> torch 0.13.0 2024-05-21 [1] CRAN (R 4.4.0)
#> tune * 1.2.1 2024-04-18 [1] CRAN (R 4.4.0)
#> utf8 1.2.4 2023-10-22 [1] CRAN (R 4.4.0)
#> vctrs 0.6.5 2023-12-01 [1] CRAN (R 4.4.0)
#> withr 3.0.1 2024-07-31 [1] CRAN (R 4.4.0)
#> workflows * 1.1.4 2024-02-19 [1] CRAN (R 4.4.0)
#> workflowsets * 1.1.0 2024-03-21 [1] CRAN (R 4.4.0)
#> xfun 0.46 2024-07-18 [1] CRAN (R 4.4.0)
#> yaml 2.3.10 2024-07-26 [1] CRAN (R 4.4.0)
#> yardstick * 1.3.1 2024-03-21 [1] CRAN (R 4.4.0)
#> zeallot 0.1.0 2018-01-28 [1] CRAN (R 4.4.0)
#>
#> [1] /Users/carlgoodwin/Library/R/arm64/4.4/library
#> [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
Hello @cgoo4,
I added the message on purpose, but it may be misleading. It means "the tabnet model will be fit without using the case_weights variable": that is how tabnet actually treats case_weights — they are set aside for later use by other downstream tidymodels packages.
Any suggestions for a more informative message?
Hi @cregouby - Thank you for clarifying.
If it's possible to set case_weights in more than one way, e.g. in a tidymodels workflow() and also in tabnet_fit(), then maybe the message could say that the latter is being overridden by the former?