
index and key are not present when using predict

Open · vidarsumo opened this issue · 7 comments

I have known future values, so I need to pass the data set to new_data in predict(), but it only returns .pred_lower, .pred and .pred_upper. The index (date) and key (id) are missing from the output, which matters because the data set contains multiple time series.

> forecasts_default <- predict(fitted_default, new_data = test, past_data = train)
> forecasts_default
# A tibble: 520 × 3
   .pred_lower .pred .pred_upper
         <dbl> <dbl>       <dbl>
 1       1498. 2242.       2874.
 2       1373. 1791.       2134.
 3       1259. 1671.       2016.
 4       1142. 1463.       1730.
 5       1619. 2090.       2485.
 6        467. 3079.       5370.
 7       1361. 1743.       2070.
 8       2375. 3815.       5073.
 9       2591. 3305.       3896.
10       5128. 7600.       9728.
# … with 510 more rows

Using generics::forecast() instead and skipping the known future data gives the desired output (i.e. containing date and id):

# A tibble: 520 × 5
   date       id        .pred_lower .pred .pred_upper
 * <date>     <chr>           <dbl> <dbl>       <dbl>
 1 2021-08-09 oes_13078       1498. 2242.       2874.
 2 2021-08-16 oes_13078       1598. 2323.       2971.
 3 2021-08-23 oes_13078       1510. 2379.       3257.
 4 2021-08-30 oes_13078       1572. 2410.       3216.
 5 2021-09-06 oes_13078       1356. 2280.       3259.
 6 2021-09-13 oes_13078       1359. 2090.       2809.
 7 2021-09-20 oes_13078       1383. 2130.       2771.
 8 2021-09-27 oes_13078       1324. 2074.       2847.
 9 2021-10-04 oes_13078       1355. 2254.       2973.
10 2021-10-11 oes_13078       1371. 2068.       2524.
# … with 510 more rows

Do I need to do anything for the predict() function to return the index and the key?

vidarsumo, Aug 07 '22 23:08

I guess my question is how do I use known and static information when creating forecasts?

vidarsumo, Aug 09 '22 00:08

Hello @vidarsumo,

generics::forecast() should definitely be used.

To define the covariates, did you use the syntax from the Getting Started page? Does the 'specifying the covariate' section fail in your experiment?

Hope it helps

cregouby, Aug 17 '22 14:08

If I'm not mistaken, generics::forecast() does not accept new_data, while predict() does. The only thing that worked was to use predict() and then bind_cols() to get the id, Date, etc.

suppressPackageStartupMessages(library(tidymodels))
library(tft)
set.seed(1)
torch::torch_manual_seed(1)



# Preparing data
data_tbl <- timetk::walmart_sales_weekly %>%
  select(id, Dept, Date, Weekly_Sales, IsHoliday) %>% 
  mutate(
    Dept = paste0("Dept_", Dept),
    IsHoliday = ifelse(IsHoliday, "yes", "no"))

fit_data <- data_tbl %>% 
  filter(Date <= "2012-08-03")

future_data <- data_tbl %>% 
  filter(Date > "2012-08-03") %>% 
  mutate(Weekly_Sales = NA_real_)


# TFT
rec <- recipe(Weekly_Sales ~ ., data = fit_data) %>%
  timetk::step_timeseries_signature(Date) %>%
  step_zv(all_predictors()) %>% 
  step_normalize(all_numeric_predictors())


spec <- tft_dataset_spec(rec, fit_data) %>%
  spec_covariate_index(Date) %>%
  spec_covariate_key(id) %>%
  spec_covariate_known(starts_with("Date_"), IsHoliday) %>%
  spec_covariate_static(Dept) %>% 
  spec_time_splits(lookback = 52, horizon = 12) %>%
  prep()

tft_model <- temporal_fusion_transformer(spec)

fitted <- tft_model %>%
  fit(
    transform(spec),
    epochs = 1,
    verbose = TRUE,
    dataloader_options = list(batch_size = 64, num_workers = 4)
  )

# Using forecast() without new_data
generics::forecast(fitted, past_data = fit_data)
#> Error in `adjust_new_data()`:
#> ! Known or static variable is missing from `new_data`.
#> ✖ Check for `Dept`.

# Using forecast() with new_data
generics::forecast(fitted, past_data = fit_data, new_data = future_data)
#> Error in forecast.tft_result(fitted, past_data = fit_data, new_data = future_data) : 
#>   unused argument (new_data = future_data)

# Using predict() (no id, date etc. present)
predict(object = fitted, new_data = future_data, past_data = fit_data)
#> # A tibble: 84 × 3
#>    .pred_lower  .pred .pred_upper
#>          <dbl>  <dbl>       <dbl>
#>  1      13556. 22793.      41542.
#>  2      12762. 20972.      40915.
#>  3      11930. 19535.      39995.
#>  4      12875. 21650.      43980.
#>  5      12946. 20197.      34923.
#>  6      14671. 21054.      36676.
#>  7      14294. 20423.      36560.
#>  8      14550. 20447.      38125.
#>  9      12553. 20296.      28059.
#> 10      13694. 21042.      32854.

# Using predict() with bind_cols() does the trick.
predict(object = fitted, new_data = future_data, past_data = fit_data) %>% 
  bind_cols(future_data %>% select(-Weekly_Sales))

Am I doing something wrong here?

vidarsumo, Aug 20 '22 22:08

Hello @vidarsumo

Sorry, my mistake; you are right:

  1. forecast() returns the keys and index, but it is documented as something that "can only be used if the model object doesn't include known predictors".
  2. predict() uses the known predictors passed in new_data =, but the source code drops the key and index variables just before returning the result, at https://github.com/mlverse/tft/blob/b8f3b115bba6d31cae2ba361a850df4e39669088/R/predict.R#L62. Maybe @dfalbel knows of a design pattern that argues against adding extra variables to the predict() output via a switch parameter such as all_vars = FALSE? In any case, it would be easy to modify, so maybe you could propose a pull request?

cregouby, Aug 25 '22 17:08
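One way to prototype the all_vars switch suggested above, without patching tft, is a thin wrapper around predict(). This is a hypothetical sketch: predict_with_ids() and its ids and all_vars arguments are not part of tft. It assumes that predict() returns exactly one row per row of new_data, in the same order. Matching row counts do not by themselves guarantee matching order, so check this against the tft source before relying on it. A toy model class stands in for a fitted tft model so the wrapper can be run on its own.

```r
suppressPackageStartupMessages({
  library(dplyr)
  library(tibble)
})

# Hypothetical wrapper (not part of tft): call the model's predict() method and,
# when all_vars = TRUE, prepend the key/index columns taken from new_data.
predict_with_ids <- function(object, new_data, ..., ids = c("id", "Date"),
                             all_vars = TRUE) {
  preds <- predict(object, new_data = new_data, ...)
  if (!all_vars) {
    return(preds)
  }
  # Positional binding is only safe if predict() preserves new_data's rows.
  if (nrow(preds) != nrow(new_data)) {
    stop("predict() returned ", nrow(preds), " rows but new_data has ",
         nrow(new_data), " rows; cannot bind key/index columns by position.")
  }
  bind_cols(select(new_data, all_of(ids)), preds)
}

# Toy stand-in for a fitted tft model, used only to exercise the wrapper.
predict.toy_model <- function(object, new_data, ...) {
  tibble(.pred = as.numeric(seq_len(nrow(new_data))))
}
toy_fit <- structure(list(), class = "toy_model")

toy_new_data <- tibble(
  id = c("a", "a", "b"),
  Date = as.Date("2012-08-10") + c(0, 7, 0),
  IsHoliday = "no"
)

predict_with_ids(toy_fit, toy_new_data)
```

With the objects from the reprex, the call would look like predict_with_ids(fitted, future_data, past_data = fit_data). Extra arguments such as past_data are passed through ... to predict().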

Is it possible to predict on the complete data, e.g. predict(object = fitted, new_data = data_tbl, past_data = fit_data)? If so, please let me know; it would be very helpful for assessing the accuracy of the fitted model. The line above generates an error, so any help with this would be highly appreciated.

Ujjwal4CULS, Mar 25 '23 21:03
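One way to do the accuracy assessment without asking predict() to cover the complete data is to forecast only the held-out period. You can then compare those predictions with the actual Weekly_Sales values that were set to NA in future_data. The sketch below assumes the predictions already carry id and Date, for example after the bind_cols() workaround. score_forecasts() is a hypothetical helper, not a tft function. It joins predictions and actuals on key and index and reports MAE and RMSE per series. The toy values are for illustration only.

```r
suppressPackageStartupMessages({
  library(dplyr)
  library(tibble)
})

# Hypothetical helper (not part of tft): join predictions to held-out actuals
# by key and index, then compute MAE and RMSE for each series.
score_forecasts <- function(preds, actuals, key = "id", index = "Date",
                            target = "Weekly_Sales") {
  preds %>%
    select(-any_of(target)) %>%  # drop the NA placeholder column, if present
    inner_join(select(actuals, all_of(c(key, index, target))),
               by = c(key, index)) %>%
    group_by(across(all_of(key))) %>%
    summarise(
      mae  = mean(abs(.data[[target]] - .pred)),
      rmse = sqrt(mean((.data[[target]] - .pred)^2)),
      .groups = "drop"
    )
}

# Toy data for illustration only.
dates <- as.Date(c("2012-08-10", "2012-08-17"))
toy_actuals <- tibble(id = rep(c("a", "b"), each = 2), Date = rep(dates, 2),
                      Weekly_Sales = c(10, 20, 5, 5))
toy_preds <- tibble(id = rep(c("a", "b"), each = 2), Date = rep(dates, 2),
                    .pred = c(12, 18, 5, 7))

score_forecasts(toy_preds, toy_actuals)
```

With the reprex objects, the actuals would be filter(data_tbl, Date > "2012-08-03"), and the predictions would be the predict() output with id and Date bound back on.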

Hello @Ujjwal4CULS

I cannot reproduce your issue with the example documented here. Could you please open a dedicated issue with a reproducible example?

cregouby, Mar 28 '23 11:03

Sorry for the incomplete information. For example, with the code above, is it possible to apply predict() to data_tbl like this?

suppressPackageStartupMessages(library(tidymodels))
library(tft)
set.seed(1)
torch::torch_manual_seed(1)

# Preparing data
data_tbl <- timetk::walmart_sales_weekly %>%
  select(id, Dept, Date, Weekly_Sales, IsHoliday) %>%
  mutate(
    Dept = paste0("Dept_", Dept),
    IsHoliday = ifelse(IsHoliday, "yes", "no"))

fit_data <- data_tbl %>%
  filter(Date <= "2012-08-03")

future_data <- data_tbl %>%
  filter(Date > "2012-08-03") %>%
  mutate(Weekly_Sales = NA_real_)

# TFT
rec <- recipe(Weekly_Sales ~ ., data = fit_data) %>%
  timetk::step_timeseries_signature(Date) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors())

spec <- tft_dataset_spec(rec, fit_data) %>%
  spec_covariate_index(Date) %>%
  spec_covariate_key(id) %>%
  spec_covariate_known(starts_with("Date_"), IsHoliday) %>%
  spec_covariate_static(Dept) %>%
  spec_time_splits(lookback = 52, horizon = 12) %>%
  prep()

tft_model <- temporal_fusion_transformer(spec)

fitted <- tft_model %>%
  fit(
    transform(spec),
    epochs = 1,
    verbose = TRUE,
    dataloader_options = list(batch_size = 64, num_workers = 4)
  )

predict(object = fitted, new_data = data_tbl, past_data = fit_data)

Ujjwal4CULS, Mar 28 '23 15:03