tabnet icon indicating copy to clipboard operation
tabnet copied to clipboard

Feature request for `case-weights`

Open cgoo4 opened this issue 2 years ago • 5 comments

Would it be possible to add support for case weights in TabNet?

This would help with class imbalance and make it easier to compare (and blend) the results of TabNet and XGBoost.

(I will probably upsample the minority class in the meantime as an alternative approach.)

This would be the desired workflow:

# Desired workflow: fit TabNet inside a tidymodels workflow that carries
# case weights, exactly as one would for XGBoost.
library(tabnet)
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")

# Number of "good" rows per "bad" row. Using it as the weight of every
# "bad" row gives both classes the same total weight.
class_ratio <- lending_club |> 
  summarise(sum(Class == "good") / sum(Class == "bad")) |> 
  pull()

# Build the per-row weight, then wrap it with importance_weights() so the
# column is recognised as a case-weights column (it is excluded from the
# predictors below via the "case_weights" role).
lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

# Stratify on Class so train and test keep the same class balance.
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

# Start from a recipe without a formula and assign roles explicitly:
# Class is the outcome; every column that is not an outcome, id or
# case-weights column becomes a predictor.
tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

set.seed(1)

tab_mod <- tabnet(epochs = 10) |> 
  set_engine("torch", device = "cpu") |> 
  set_mode("classification")

# add_case_weights() tells the workflow which column holds the case weights
# to pass to the model fit.
tab_wf <- workflow() |> 
  add_model(tab_mod) |> 
  add_recipe(tab_rec) |> 
  add_case_weights(case_wts)

# Fails here: parsnip's check_case_weights() aborts because the tabnet
# engine does not enable case weights (see the recorded error below).
tab_fit <- tab_wf |> fit(train)
#> Error in `check_case_weights()`:
#> ! Case weights are not enabled by the underlying model implementation.
#> Backtrace:
#>      ▆
#>   1. ├─generics::fit(tab_wf, train)
#>   2. └─workflows:::fit.workflow(tab_wf, train)
#>   3.   └─workflows::.fit_model(workflow, control)
#>   4.     ├─generics::fit(action_model, workflow = workflow, control = control)
#>   5.     └─workflows:::fit.action_model(...)
#>   6.       └─workflows:::fit_from_xy(spec, mold, case_weights, control_parsnip)
#>   7.         ├─generics::fit_xy(...)
#>   8.         └─parsnip::fit_xy.model_spec(...)
#>   9.           └─parsnip:::check_case_weights(case_weights, object)
#>  10.             └─rlang::abort("Case weights are not enabled by the underlying model implementation.")

Created on 2024-01-12 with reprex v2.0.2

cgoo4 avatar Jan 12 '24 16:01 cgoo4

Hello @cgoo4, I finally did it. Would you like to test it and report whether it fits your needs? One way to install it is:

# Install tabnet from the `feature/case_weight` development branch on GitHub.
pak::pkg_install("mlverse/tabnet@feature/case_weight")

cregouby avatar Feb 18 '24 10:02 cregouby

Hi @cregouby - Thank you.

Ahead of trying it on my own data, I've run a quick test on the toy `lending_club` data. Untuned TabNet and XGBoost models, both fit with case weights and evaluated with and without weighted metrics, show comparable results!

# Compare untuned TabNet and XGBoost, both fit with case weights, using
# precision-recall curves computed with and without metric weighting.
library(tabnet)
library(tidymodels)
library(modeldata)
library(patchwork)

data("lending_club", package = "modeldata")

# Number of "good" rows per "bad" row. Using it as the weight of every
# "bad" row gives both classes the same total weight.
class_ratio <- lending_club |>
  summarise(sum(Class == "good") / sum(Class == "bad")) |>
  pull()

# importance_weights() marks the column as case weights so it gets the
# "case_weights" role rather than being used as a predictor.
lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

# Assign roles explicitly: Class is the outcome; every column that is not
# an outcome, id or case-weights column is a predictor.
tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(
    -has_role(c("outcome", "id", "case_weights")),
    new_role = "predictor"
  )

# XGBoost gets dummy-encoded factor predictors; TabNet uses them as-is.
xgb_rec <- tab_rec |>
  step_dummy(term, sub_grade, addr_state, verification_status, emp_length)

tab_mod <- tabnet(epochs = 100) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

xgb_mod <- boost_tree(trees = 100) |>
  set_engine("xgboost") |>
  set_mode("classification")

# Both workflows carry the case weights into model fitting.
tab_wf <- workflow() |>
  add_model(tab_mod) |>
  add_recipe(tab_rec) |>
  add_case_weights(case_wts)

xgb_wf <- workflow() |>
  add_model(xgb_mod) |>
  add_recipe(xgb_rec) |>
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
xgb_fit <- xgb_wf |> fit(train)

tab_test <- tab_fit |> augment(test)
xgb_test <- xgb_fit |> augment(test)

# Draw a PR curve for the "good" class from augmented predictions.
#
# `weighted` only controls whether the *metric* uses case_wts; every model
# plotted here was fit with case weights.
#
# Without an explicit event_level, yardstick treats the first factor level
# as the event. Class is assumed to have levels c("bad", "good")
# (alphabetical; confirm with levels(train$Class)). The default would
# therefore score the "bad" event with .pred_good and invert the curve.
# event_level = "second" makes "good" the event, matching .pred_good.
plot_pr <- function(preds, title, weighted) {
  curve <- if (weighted) {
    pr_curve(
      preds, Class, .pred_good,
      case_weights = case_wts,
      event_level = "second"
    )
  } else {
    pr_curve(preds, Class, .pred_good, event_level = "second")
  }
  autoplot(curve) +
    ggtitle(title) +
    theme(plot.title = element_text(size = 9))
}

p1 <- plot_pr(tab_test, "TabNet with Case Weights", weighted = TRUE)
p2 <- plot_pr(tab_test, "TabNet WITHOUT", weighted = FALSE)
p3 <- plot_pr(xgb_test, "XGBoost with Case Weights", weighted = TRUE)
p4 <- plot_pr(xgb_test, "XGBoost WITHOUT", weighted = FALSE)

p1 + p2 + p3 + p4

Created on 2024-02-18 with reprex v2.1.0

cgoo4 avatar Feb 18 '24 12:02 cgoo4

In 0.6.0.9000 I'm getting the message "Configured `weights` will not be used":

(It's the same example as above, where the weights are passed along from the workflow.)

# Same workflow as before, rerun with tabnet 0.6.0.9000 to show the
# message printed when the case weights are supplied via the workflow.
library(tabnet)
library(tidymodels)

data("lending_club", package = "modeldata")

# Number of "good" rows per "bad" row; used as the weight of each "bad" row
# so both classes carry the same total weight.
class_ratio <- lending_club |> 
  summarise(sum(Class == "good") / sum(Class == "bad")) |> 
  pull()

# Per-row weight wrapped as importance weights, i.e. a case-weights column.
lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

# Class is the outcome; every column that is not an outcome, id or
# case-weights column becomes a predictor.
tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

tab_mod <- tabnet(epochs = 10) |> 
  set_engine("torch", device = "cpu") |> 
  set_mode("classification")

# The case weights are handed to the model through the workflow.
tab_wf <- workflow() |> 
  add_model(tab_mod) |> 
  add_recipe(tab_rec) |> 
  add_case_weights(case_wts)

# With tabnet 0.6.0.9000 this fit prints the message recorded below even
# though case weights were added to the workflow.
tab_fit <- tab_wf |> fit(train)
#> Configured `weights` will not be used

tab_test <- tab_fit |> augment(test)

Created on 2024-08-09 with reprex v2.1.1

Session info

# Record the R and package versions used for this reprex.
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.1 (2024-06-14)
#>  os       macOS Sonoma 14.5
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/London
#>  date     2024-08-09
#>  pandoc   3.2.1 @ /opt/homebrew/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
#>  bit            4.0.5      2022-11-15 [1] CRAN (R 4.4.0)
#>  bit64          4.0.5      2020-08-30 [1] CRAN (R 4.4.0)
#>  broom        * 1.0.6      2024-05-17 [1] CRAN (R 4.4.0)
#>  callr          3.7.6      2024-03-25 [1] CRAN (R 4.4.0)
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.4.1)
#>  cli            3.6.3      2024-06-21 [1] CRAN (R 4.4.0)
#>  codetools      0.2-20     2024-03-31 [2] CRAN (R 4.4.1)
#>  colorspace     2.1-1      2024-07-26 [1] CRAN (R 4.4.0)
#>  coro           1.0.4      2024-03-11 [1] CRAN (R 4.4.0)
#>  data.table     1.15.4     2024-03-30 [1] CRAN (R 4.4.0)
#>  dials        * 1.3.0      2024-07-30 [1] CRAN (R 4.4.0)
#>  DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.4.0)
#>  digest         0.6.36     2024-06-23 [1] CRAN (R 4.4.0)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.4.0)
#>  evaluate       0.24.0     2024-06-10 [1] CRAN (R 4.4.0)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.4.0)
#>  fastmap        1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.4.0)
#>  fs             1.6.4      2024-04-25 [1] CRAN (R 4.4.0)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.4.0)
#>  future         1.34.0     2024-07-29 [1] CRAN (R 4.4.0)
#>  future.apply   1.11.2     2024-03-28 [1] CRAN (R 4.4.0)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.4.0)
#>  ggplot2      * 3.5.1      2024-04-23 [1] CRAN (R 4.4.0)
#>  globals        0.16.3     2024-03-08 [1] CRAN (R 4.4.0)
#>  glue           1.7.0      2024-01-09 [1] CRAN (R 4.4.0)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.4.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.4.0)
#>  gtable         0.3.5      2024-04-22 [1] CRAN (R 4.4.0)
#>  hardhat        1.4.0      2024-06-02 [1] CRAN (R 4.4.0)
#>  htmltools      0.5.8.1    2024-04-04 [1] CRAN (R 4.4.0)
#>  infer        * 1.0.7      2024-03-25 [1] CRAN (R 4.4.0)
#>  ipred          0.9-15     2024-07-18 [1] CRAN (R 4.4.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.4.0)
#>  jsonlite       1.8.8      2023-12-04 [1] CRAN (R 4.4.0)
#>  knitr          1.48       2024-07-07 [1] CRAN (R 4.4.0)
#>  lattice        0.22-6     2024-03-20 [2] CRAN (R 4.4.1)
#>  lava           1.8.0      2024-03-05 [1] CRAN (R 4.4.0)
#>  lhs            1.2.0      2024-06-30 [1] CRAN (R 4.4.0)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.4.0)
#>  listenv        0.9.1      2024-01-29 [1] CRAN (R 4.4.0)
#>  lubridate      1.9.3      2023-09-27 [1] CRAN (R 4.4.0)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.4.0)
#>  MASS           7.3-60.2   2024-04-26 [2] CRAN (R 4.4.1)
#>  Matrix         1.7-0      2024-04-26 [2] CRAN (R 4.4.1)
#>  modeldata    * 1.4.0      2024-06-19 [1] CRAN (R 4.4.0)
#>  munsell        0.5.1      2024-04-01 [1] CRAN (R 4.4.0)
#>  nnet           7.3-19     2023-05-03 [2] CRAN (R 4.4.1)
#>  parallelly     1.38.0     2024-07-27 [1] CRAN (R 4.4.0)
#>  parsnip      * 1.2.1      2024-03-22 [1] CRAN (R 4.4.0)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.4.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.4.0)
#>  processx       3.8.4      2024-03-16 [1] CRAN (R 4.4.0)
#>  prodlim        2024.06.25 2024-06-24 [1] CRAN (R 4.4.0)
#>  ps             1.7.7      2024-07-02 [1] CRAN (R 4.4.0)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.4.0)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.4.0)
#>  Rcpp           1.0.13     2024-07-17 [1] CRAN (R 4.4.0)
#>  recipes      * 1.1.0      2024-07-04 [1] CRAN (R 4.4.0)
#>  reprex         2.1.1      2024-07-06 [1] CRAN (R 4.4.0)
#>  rlang          1.1.4      2024-06-04 [1] CRAN (R 4.4.0)
#>  rmarkdown      2.27       2024-05-17 [1] CRAN (R 4.4.0)
#>  rpart          4.1.23     2023-12-05 [2] CRAN (R 4.4.1)
#>  rsample      * 1.2.1      2024-03-25 [1] CRAN (R 4.4.0)
#>  rstudioapi     0.16.0     2024-03-24 [1] CRAN (R 4.4.0)
#>  safetensors    0.1.2      2023-09-12 [1] CRAN (R 4.4.0)
#>  scales       * 1.3.0      2023-11-28 [1] CRAN (R 4.4.0)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.4.0)
#>  survival       3.6-4      2024-04-24 [2] CRAN (R 4.4.1)
#>  tabnet       * 0.6.0.9000 2024-08-09 [1] Github (mlverse/tabnet@c8c82d2)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.4.0)
#>  tidymodels   * 1.2.0      2024-03-25 [1] CRAN (R 4.4.0)
#>  tidyr        * 1.3.1      2024-01-24 [1] CRAN (R 4.4.0)
#>  tidyselect     1.2.1      2024-03-11 [1] CRAN (R 4.4.0)
#>  timechange     0.3.0      2024-01-18 [1] CRAN (R 4.4.0)
#>  timeDate       4032.109   2023-12-14 [1] CRAN (R 4.4.0)
#>  torch          0.13.0     2024-05-21 [1] CRAN (R 4.4.0)
#>  tune         * 1.2.1      2024-04-18 [1] CRAN (R 4.4.0)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.4.0)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.4.0)
#>  withr          3.0.1      2024-07-31 [1] CRAN (R 4.4.0)
#>  workflows    * 1.1.4      2024-02-19 [1] CRAN (R 4.4.0)
#>  workflowsets * 1.1.0      2024-03-21 [1] CRAN (R 4.4.0)
#>  xfun           0.46       2024-07-18 [1] CRAN (R 4.4.0)
#>  yaml           2.3.10     2024-07-26 [1] CRAN (R 4.4.0)
#>  yardstick    * 1.3.1      2024-03-21 [1] CRAN (R 4.4.0)
#>  zeallot        0.1.0      2018-01-28 [1] CRAN (R 4.4.0)
#> 
#>  [1] /Users/carlgoodwin/Library/R/arm64/4.4/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

cgoo4 avatar Aug 09 '24 10:08 cgoo4

Hello @cgoo4,

I added the message on purpose, but it may be misleading. It means "the tabnet model will be fit without using the `case_weights` variable". That is how tabnet actually treats case weights: they are set aside for later use by other downstream tidymodels packages.

Any suggestions for a more informative message?

cregouby avatar Aug 13 '24 06:08 cregouby

Hi @cregouby - Thank you for clarifying.

If it's possible to set case weights in more than one way, e.g. in a tidymodels `workflow()` and also in `tabnet_fit()`, then maybe the message could say that the latter is being overridden by the former?

cgoo4 avatar Aug 13 '24 10:08 cgoo4