Index and key are not present when using predict()
I have known future values, so I need to pass the data set to `new_data` in `predict()`, but it only returns `.pred_lower`, `.pred` and `.pred_upper`. Neither the index (date) nor the key (id) is present in the output, which matters because the data set contains multiple time series.
```
> forecasts_default <- predict(fitted_default, new_data = test, past_data = train)
> forecasts_default
# A tibble: 520 × 3
   .pred_lower .pred .pred_upper
         <dbl> <dbl>       <dbl>
 1       1498. 2242.       2874.
 2       1373. 1791.       2134.
 3       1259. 1671.       2016.
 4       1142. 1463.       1730.
 5       1619. 2090.       2485.
 6        467. 3079.       5370.
 7       1361. 1743.       2070.
 8       2375. 3815.       5073.
 9       2591. 3305.       3896.
10       5128. 7600.       9728.
# … with 510 more rows
```
Using generics::forecast() instead and skipping the known future data gives the desired output (i.e. containing date and id):
```
# A tibble: 520 × 5
   date       id        .pred_lower .pred .pred_upper
 * <date>     <chr>           <dbl> <dbl>       <dbl>
 1 2021-08-09 oes_13078       1498. 2242.       2874.
 2 2021-08-16 oes_13078       1598. 2323.       2971.
 3 2021-08-23 oes_13078       1510. 2379.       3257.
 4 2021-08-30 oes_13078       1572. 2410.       3216.
 5 2021-09-06 oes_13078       1356. 2280.       3259.
 6 2021-09-13 oes_13078       1359. 2090.       2809.
 7 2021-09-20 oes_13078       1383. 2130.       2771.
 8 2021-09-27 oes_13078       1324. 2074.       2847.
 9 2021-10-04 oes_13078       1355. 2254.       2973.
10 2021-10-11 oes_13078       1371. 2068.       2524.
# … with 510 more rows
```
Do I need to do anything for the predict() function to return the index and the key?
I guess my real question is: how do I use known and static information when creating forecasts?
Hello @vidarsumo,
Definitely, `generics::forecast()` should be used.
To define the covariates, did you use the syntax from the Getting Started page? Does the "Specifying the covariates" section fail in your experiment?
Hope it helps
If I'm not mistaken, `generics::forecast()` does not accept `new_data`, while `predict()` does.
The only thing that worked was to use `predict()` and then `bind_cols()` to add the id, Date, etc.
```r
suppressPackageStartupMessages(library(tidymodels))
library(tft)
set.seed(1)
torch::torch_manual_seed(1)

# Preparing data
data_tbl <- timetk::walmart_sales_weekly %>%
  select(id, Dept, Date, Weekly_Sales, IsHoliday) %>%
  mutate(
    Dept = paste0("Dept_", Dept),
    IsHoliday = ifelse(IsHoliday, "yes", "no")
  )

fit_data <- data_tbl %>%
  filter(Date <= "2012-08-03")

future_data <- data_tbl %>%
  filter(Date > "2012-08-03") %>%
  mutate(Weekly_Sales = NA_real_)

# TFT
rec <- recipe(Weekly_Sales ~ ., data = fit_data) %>%
  timetk::step_timeseries_signature(Date) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors())

spec <- tft_dataset_spec(rec, fit_data) %>%
  spec_covariate_index(Date) %>%
  spec_covariate_key(id) %>%
  spec_covariate_known(starts_with("Date_"), IsHoliday) %>%
  spec_covariate_static(Dept) %>%
  spec_time_splits(lookback = 52, horizon = 12) %>%
  prep()

tft_model <- temporal_fusion_transformer(spec)

fitted <- tft_model %>%
  fit(
    transform(spec),
    epochs = 1,
    verbose = TRUE,
    dataloader_options = list(batch_size = 64, num_workers = 4)
  )

# Using forecast() without new_data
generics::forecast(fitted, past_data = fit_data)
#> Error in `adjust_new_data()`:
#> ! Known or static variable is missing from `new_data`.
#> ✖ Check for `Dept`.

# Using forecast() with new_data
generics::forecast(fitted, past_data = fit_data, new_data = future_data)
#> Error in forecast.tft_result(fitted, past_data = fit_data, new_data = future_data) :
#>   unused argument (new_data = future_data)

# Using predict() (no id, date etc. present)
predict(object = fitted, new_data = future_data, past_data = fit_data)
#> # A tibble: 84 × 3
#>    .pred_lower  .pred .pred_upper
#>          <dbl>  <dbl>       <dbl>
#>  1      13556. 22793.      41542.
#>  2      12762. 20972.      40915.
#>  3      11930. 19535.      39995.
#>  4      12875. 21650.      43980.
#>  5      12946. 20197.      34923.
#>  6      14671. 21054.      36676.
#>  7      14294. 20423.      36560.
#>  8      14550. 20447.      38125.
#>  9      12553. 20296.      28059.
#> 10      13694. 21042.      32854.

# Using predict() with bind_cols() does the trick.
predict(object = fitted, new_data = future_data, past_data = fit_data) %>%
  bind_cols(future_data %>% select(-Weekly_Sales))
```
Am I doing something wrong here?
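The `bind_cols()` workaround can be wrapped in a small helper so it fails loudly instead of silently misaligning rows. This is a minimal sketch: `bind_pred_keys()` is a hypothetical name, not part of tft, and it assumes that `predict()` returns exactly one row per row of `new_data`, in the same order. That assumption is what the `bind_cols()` workaround relies on; it is not something the tft docs confirm. The toy data is made up for illustration.

```r
# Hypothetical helper (not part of tft): attach key/index columns from
# new_data to a predict() result. ASSUMES the prediction rows are in the
# same order as new_data, as the bind_cols() workaround does.
bind_pred_keys <- function(preds, new_data, cols = c("id", "Date")) {
  # Refuse to align by position if the row counts disagree
  if (nrow(preds) != nrow(new_data)) {
    stop("predict() returned ", nrow(preds), " rows but new_data has ",
         nrow(new_data), " rows; cannot align them by position.")
  }
  missing_cols <- setdiff(cols, names(new_data))
  if (length(missing_cols) > 0) {
    stop("Columns not found in new_data: ", paste(missing_cols, collapse = ", "))
  }
  # Key/index columns first, then the .pred* columns
  cbind(new_data[, cols, drop = FALSE], preds)
}

# Toy stand-ins for future_data and the predict() output (made-up numbers)
toy_future <- data.frame(
  id = rep(c("a", "b"), each = 2),
  Date = as.Date("2012-08-10") + c(0, 7, 0, 7),
  Weekly_Sales = NA_real_
)
toy_preds <- data.frame(
  .pred_lower = c(1, 2, 3, 4),
  .pred       = c(2, 3, 4, 5),
  .pred_upper = c(3, 4, 5, 6)
)

out <- bind_pred_keys(toy_preds, toy_future)
print(out)
```

With the objects from this thread, the call would be `bind_pred_keys(predict(fitted, new_data = future_data, past_data = fit_data), future_data)`, again assuming row order is preserved.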
Hello @vidarsumo
Sorry, my mistake, you are right:

- `forecast()` provides the keys and index, but its documentation says it "can only be used if the model object doesn't include `known` predictors".
- `predict()` uses the `known` predictors in `new_data =`, but the source code removes the key and index variables just before returning the result, at https://github.com/mlverse/tft/blob/b8f3b115bba6d31cae2ba361a850df4e39669088/R/predict.R#L62. Maybe @dfalbel knows of a design convention that rules out adding extra variables to the `predict()` output via a switch parameter such as `all_vars = FALSE`? Anyway, it would be easy to modify, so maybe you could propose a pull request?
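To make the `all_vars` idea concrete, here is a minimal sketch of a final post-processing step with such a switch. It assumes that, at that point, the internal result is a data frame holding the key, index and `.pred*` columns. `finalize_predictions()` and `all_vars` are hypothetical names for illustration, not tft's actual code.

```r
# Hypothetical post-processing step (not tft's code): keep only the
# prediction columns by default, or also the key/index columns when
# all_vars = TRUE.
finalize_predictions <- function(result, key_cols, index_col, all_vars = FALSE) {
  # Prediction columns are identified by the ".pred" prefix
  pred_cols <- grep("^\\.pred", names(result), value = TRUE)
  keep <- if (all_vars) c(key_cols, index_col, pred_cols) else pred_cols
  result[, keep, drop = FALSE]
}

# Toy internal result (made-up numbers)
res <- data.frame(
  id = c("a", "b"),
  Date = as.Date(c("2012-08-10", "2012-08-10")),
  .pred_lower = 1:2,
  .pred = 2:3,
  .pred_upper = 3:4
)

print(finalize_predictions(res, key_cols = "id", index_col = "Date"))
print(finalize_predictions(res, key_cols = "id", index_col = "Date", all_vars = TRUE))
```

Defaulting to `all_vars = FALSE` keeps the current output (prediction columns only), so existing callers would not break.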
Is it possible to predict on the complete data, e.g. `predict(object = fitted, new_data = data_tbl, past_data = fit_data)`? If so, please let me know; it would be very helpful for assessing the accuracy of the fitted model. That line currently generates an error, so any help would be highly appreciated.
Hello @Ujjwal4CULS
I cannot reproduce your issue with the example documented here. Could you please open a dedicated issue with a reproducible example?
Sorry for the incomplete information. For example, using the code above, is it possible to apply `predict()` to `data_tbl` like this:
```r
suppressPackageStartupMessages(library(tidymodels))
library(tft)
set.seed(1)
torch::torch_manual_seed(1)

# Preparing data
data_tbl <- timetk::walmart_sales_weekly %>%
  select(id, Dept, Date, Weekly_Sales, IsHoliday) %>%
  mutate(
    Dept = paste0("Dept_", Dept),
    IsHoliday = ifelse(IsHoliday, "yes", "no")
  )

fit_data <- data_tbl %>%
  filter(Date <= "2012-08-03")

future_data <- data_tbl %>%
  filter(Date > "2012-08-03") %>%
  mutate(Weekly_Sales = NA_real_)

# TFT
rec <- recipe(Weekly_Sales ~ ., data = fit_data) %>%
  timetk::step_timeseries_signature(Date) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors())

spec <- tft_dataset_spec(rec, fit_data) %>%
  spec_covariate_index(Date) %>%
  spec_covariate_key(id) %>%
  spec_covariate_known(starts_with("Date_"), IsHoliday) %>%
  spec_covariate_static(Dept) %>%
  spec_time_splits(lookback = 52, horizon = 12) %>%
  prep()

tft_model <- temporal_fusion_transformer(spec)

fitted <- tft_model %>%
  fit(
    transform(spec),
    epochs = 1,
    verbose = TRUE,
    dataloader_options = list(batch_size = 64, num_workers = 4)
  )

predict(object = fitted, new_data = data_tbl, past_data = fit_data)
```
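For the accuracy-assessment use case, one option that avoids predicting on the full `data_tbl` is to keep the true `Weekly_Sales` values that `future_data` masks with `NA` and compare them against the hold-out predictions. Below is a minimal base-R sketch. `forecast_accuracy()` is a hypothetical helper, and the toy numbers are made up. It reports MAE, RMSE and the share of actuals falling inside `[.pred_lower, .pred_upper]`.

```r
# Hypothetical helper: point and interval accuracy for a hold-out period.
forecast_accuracy <- function(actual, pred, lower, upper) {
  # Drop rows where either the actual or the point forecast is missing
  ok <- !is.na(actual) & !is.na(pred)
  err <- actual[ok] - pred[ok]
  c(
    mae      = mean(abs(err)),
    rmse     = sqrt(mean(err^2)),
    # Share of actuals inside [lower, upper]; compare with the interval's nominal level
    coverage = mean(actual[ok] >= lower[ok] & actual[ok] <= upper[ok])
  )
}

# Toy example (made-up values)
actual <- c(10, 12, 9, 15)
pred   <- c(11, 12, 10, 13)
lower  <- c(8, 10, 9.5, 12)
upper  <- c(13, 14, 11, 14)
acc <- forecast_accuracy(actual, pred, lower, upper)
print(acc)

# Applied to the objects from this thread (untested; assumes predict() rows
# line up with future_data row by row):
# actuals <- data_tbl %>% filter(Date > "2012-08-03") %>% pull(Weekly_Sales)
# preds <- predict(fitted, new_data = future_data, past_data = fit_data)
# forecast_accuracy(actuals, preds$.pred, preds$.pred_lower, preds$.pred_upper)
```

For the point metrics, `yardstick::mae_vec()` and `yardstick::rmse_vec()` (loaded with tidymodels) compute the same quantities. Per-series metrics can be obtained by splitting on `id` before calling the helper.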