
[bug report] Prediction difference between R XGBoost model and translated SQL when base_score is set

Open • JiaxiangBU opened this issue 6 years ago • 1 comment

I use the mtcars dataset for the reproducible example below.

library(xgboost)
#> Warning: package 'xgboost' was built under R version 3.6.1
library(tidyverse)
#> Registered S3 methods overwritten by 'ggplot2':
#>   method         from 
#>   [.quosures     rlang
#>   c.quosures     rlang
#>   print.quosures rlang
#> Warning: package 'dplyr' was built under R version 3.6.1
train_data <- mtcars %>% 
    rename(y = am)
dtrain <- 
    xgb.DMatrix(
        data = as.matrix(
            train_data %>% select(-y)
        )
        ,label = train_data$y
    )
xgb_model <- xgb.train(
    data=dtrain,
    nrounds = 10,
    seed = 1, 
    max_depth = 1,
    objective = "binary:logistic",
    base_score = mean(train_data$y) # fix uncalibration problem
)
pred_from_model <- predict(xgb_model, newdata = dtrain)
library(sqldf)
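#> Warning: package 'sqldf' was built under R version 3.6.1
#> Loading required package: gsubfn
#> Warning: package 'gsubfn' was built under R version 3.6.1
#> Loading required package: proto
#> Warning: package 'proto' was built under R version 3.6.1
#> Loading required package: RSQLite
#> Warning: package 'RSQLite' was built under R version 3.6.1
library(tidypredict)
#> Warning: package 'tidypredict' was built under R version 3.6.1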
pred_from_tidypredict <- 
tidypredict_sql(xgb_model, dbplyr::simulate_dbi()) %>% 
  paste("select ",.," from mtcars") %>% 
  # cat
    sqldf() %>% 
  pull
(pred_from_model-pred_from_tidypredict) %>% abs %>% mean
#> [1] 0.04692561

Created on 2019-10-20 by the reprex package (v0.3.0)
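
For readers following along, here is a minimal sketch of why a non-default base_score could produce a gap like the one above. It assumes the translated SQL applies the logistic function to the summed tree outputs only, without the logit(base_score) offset that XGBoost adds for binary:logistic; that assumption is not confirmed by the reprex itself. The tree_sum values are made up for illustration and are not taken from xgb_model.

```r
# Hypothetical summed leaf values (raw tree output) for three rows.
# These numbers are illustrative only, not extracted from xgb_model.
tree_sum <- c(-1.2, 0.3, 2.0)

# Same base_score as in the reprex: the mean of the binary label.
base_score <- mean(mtcars$am)

logit <- function(p) log(p / (1 - p))
sigmoid <- function(x) 1 / (1 + exp(-x))

# Assumed behaviour of the translation: no base_score offset
# (equivalent to base_score = 0.5, because logit(0.5) is 0).
pred_without_offset <- sigmoid(tree_sum)

# binary:logistic adds logit(base_score) to the margin before the sigmoid.
pred_with_offset <- sigmoid(tree_sum + logit(base_score))

# Mean absolute gap between the two, analogous to the comparison above.
mean(abs(pred_with_offset - pred_without_offset))
```

With the default base_score of 0.5 the offset is zero and the two vectors coincide, which would explain why the difference only shows up once base_score is set.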

JiaxiangBU • Oct 20 '19 06:10

I created a pull request to fix this problem: https://github.com/tidymodels/tidypredict/pull/66

JiaxiangBU • Oct 20 '19 07:10