This repository was archived by the owner on Jun 14, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 31
This repository was archived by the owner on Jun 14, 2023. It is now read-only.
Again Error for Unusual Split of Data into Test and Train Sets #260
Copy link
Copy link
Open
Description
Dear All,
I am going back to tidymodels and I dusted off the example discussed at https://github.com/tidymodels/tidymodels.org/issues/198 .
Please have a look at the reprex.
Now I get a new error/warning (and I am sure I did not have this specific issue in the past), which seems to be due to the test data consisting of only one data point, but I am not 100% sure of this, nor do I know what to do if this is the case (I really need this unusual data split).
I have two questions
- can anyone tell me what to do to fix my code?
- given that my data has a strong time component (in df_ini there is a "year" column), I think I should use a different resampling technique (see https://www.tmwr.org/resampling.html#rolling ).
It must be a one-liner, but I am experiencing some issues. Can anyone show me how to implement rolling forecasting origin resampling in my example (once it has been fixed), e.g. non-cumulative, with an analysis set of eight samples (eight years) and an assessment set of two samples (two years)?
Thanks a lot!
library(tidymodels)
tidymodels_prefer()
## Reprex data: a 17-row tibble with one row per year. The years are not
## contiguous (1998, 2002, then 2004-2018). `berd` is the outcome modelled
## below and `year` is only used as an identifier. All other columns are
## numeric predictors named `*_lag_1` / `*_lag_2` (presumably one- and
## two-period lags of the underlying series -- confirm against the data source).
df_ini <- structure(list(year = c(1998, 2002, 2004, 2005, 2006, 2007, 2008,
2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018),
capital_n1132g_lag_1 = c(3446.5, 4091.1, 3655.1, 3633.3,
3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6,
3718.6, 3467.9, 4214.2, 4237.4, 4450.2), capital_n117g_lag_1 = c(4920.9,
7810.6, 8560.3, 8679.9, 8938.9, 9823.8, 10467.1, 11047.1,
11554.3, 11849.9, 13465.4, 13927.5, 15510.2, 15754.4, 16584.7,
17647.1, 18273.8), capital_n11mg_lag_1 = c(16846, 19605,
19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6,
20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2,
29790.1), employment_be_lag_1 = c(2834.42, 2839.72, 2765.53,
2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1,
2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9
), employment_c_lag_1 = c(2612.76, 2623.69, 2552.89, 2518.57,
2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1,
2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48), employment_j_lag_1 = c(292.93,
389.2, 389.45, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91,
392.18, 410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4
), employment_k_lag_1 = c(505.33, 507.12, 510.25, 504.63,
515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98,
524.13, 518.89, 511.57, 505.32, 496.41), employment_mn_lag_1 = c(945.59,
1217.96, 1289.55, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76,
1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83,
2021.51, 2109.71), employment_oq_lag_1 = c(3065.87, 3191.75,
3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85,
3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59,
4171.72), employment_total_lag_1 = c(14509.58, 15127.99,
15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 16356.53,
16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 17039.6,
17142.13, 17365.32, 17650.21), gdp_b1gq_lag_1 = c(187849.7,
220525, 231862.5, 242348.3, 254075, 267824.4, 283978, 293761.9,
288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 333146.1,
344269.3, 357608, 369341.3), gdp_p3_lag_1 = c(139695.2, 161175.8,
169405.6, 176316.4, 185871.1, 194102, 200944.4, 208857.1,
213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6,
249404.3, 257166.5, 265900.2), gdp_p61_lag_1 = c(50117.6,
71948.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1,
91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3,
129183.6, 131524, 140057.8), gdp_p62_lag_1 = c(19441, 26444.4,
28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9,
39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5,
59584.7), price_index_lag_1 = c(1.2, 2.3, 1.3, 2, 2.1, 1.7,
2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 1, 2.2), value_be_lag_1 = c(40533.1,
48207.1, 48673.2, 50737.6, 52955.2, 56872.4, 60864.9, 61029,
56837.8, 58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4,
71152.6, 72698.8), value_c_lag_1 = c(33441.8, 40446.6, 40467.4,
42014.6, 44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3,
51467.7, 53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196
), value_j_lag_1 = c(5483.7, 7326.1, 7934.1, 7756.1, 8134.2,
8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1,
10361.4, 10695.4, 11455.3, 11720.6), value_k_lag_1 = c(9210.6,
9977.3, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2,
12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9,
13236.4, 13744.1), value_mn_lag_1 = c(10444, 14061.4, 15706.6,
16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2,
24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6
), value_oq_lag_1 = c(29902.7, 34179.2, 36126.8, 37329.6,
38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9,
49381.5, 50261.7, 51624.3, 53715, 55926.4, 57637.1), value_total_lag_1 = c(167323.4,
197076.7, 207247.6, 216098.3, 225888.1, 239076, 253604.6,
262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1,
307037.7, 318952.7, 329396.1), capital_n1132g_lag_2 = c(3599.2,
3996.9, 3638.4, 3655.1, 3633.3, 3616.2, 3450.7, 3596.8, 3867.2,
3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 3467.9, 4214.2, 4237.4
), capital_n117g_lag_2 = c(4636.2, 7008.5, 8369.6, 8560.3,
8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9,
13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1), capital_n11mg_lag_2 = c(17181.5,
19677.8, 18749.6, 19381.2, 19433.5, 20051.6, 20569.8, 22646.1,
23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 24057.7, 23832.8,
25019.2, 27608.2), employment_be_lag_2 = c(2870.33, 2840.19,
2775.22, 2765.53, 2731.08, 2709.59, 2708.39, 2774.06, 2795.6,
2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 2725.66, 2735.69,
2750.52), employment_c_lag_2 = c(2626.2, 2621.08, 2562.53,
2552.89, 2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97,
2447.65, 2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97
), employment_j_lag_2 = c(275.08, 374.56, 400.75, 389.45,
387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 410.85,
419.75, 427.59, 438.96, 440.33, 460.84), employment_k_lag_2 = c(500.9,
505.13, 502.42, 510.25, 504.63, 515.39, 523.45, 536.6, 550.14,
546.68, 539.96, 536.58, 534.98, 524.13, 518.89, 511.57, 505.32
), employment_mn_lag_2 = c(904.38, 1143.78, 1248.01, 1289.55,
1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 1762.55,
1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51), employment_oq_lag_2 = c(3028.85,
3162.77, 3241.36, 3280.36, 3317.09, 3401.65, 3476.63, 3508.01,
3577.75, 3683.85, 3759.23, 3798.35, 3850.17, 3877.24, 3924.06,
4002.74, 4095.59), employment_total_lag_2 = c(14404.29, 15019.87,
15113.52, 15212.11, 15307.28, 15491.61, 15762.92, 16050.92,
16356.53, 16269.97, 16392.87, 16647.79, 16820.66, 16879.06,
17039.6, 17142.13, 17365.32), gdp_b1gq_lag_2 = c(186928.7,
213606.4, 226735.3, 231862.5, 242348.3, 254075, 267824.4,
283978, 293761.9, 288044.1, 295896.6, 310128.6, 318653.1,
323910.2, 333146.1, 344269.3, 357608), gdp_p3_lag_2 = c(140335.8,
156117.3, 164107.8, 169405.6, 176316.4, 185871.1, 194102,
200944.4, 208857.1, 213630.1, 218947.2, 227250.8, 233638.1,
238329.3, 243860.6, 249404.3, 257166.5), gdp_p61_lag_2 = c(44541.4,
67701.6, 74691.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2,
113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1,
126109.3, 129183.6, 131524), gdp_p62_lag_2 = c(19504.2, 24888.9,
28063.4, 28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8,
38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8,
55885.5), value_be_lag_2 = c(40076.7, 46109.4, 47967.1, 48673.2,
50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6,
61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6), value_c_lag_2 = c(32955.4,
38908.4, 40192.9, 40467.4, 42014.6, 44229, 47735.5, 51552.4,
51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169,
57458.7, 60962.8), value_j_lag_2 = c(5576.8, 6313.9, 7737.1,
7934.1, 7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9,
9217.1, 9405.1, 9802.1, 10361.4, 10695.4, 11455.3), value_k_lag_2 = c(9191,
10458, 10225.2, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7,
13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4,
13482.9, 13236.4), value_mn_lag_2 = c(10092, 12942.5, 15074,
15706.6, 16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490,
23255.2, 24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7
), value_oq_lag_2 = c(30224.3, 33251.5, 35065.6, 36126.8,
37329.6, 38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6,
47980.9, 49381.5, 50261.7, 51624.3, 53715, 55926.4), value_total_lag_2 = c(167141.8,
190624.9, 202353.5, 207247.6, 216098.3, 225888.1, 239076,
253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3,
297230.1, 307037.7, 318952.7), berd = c(2146.085, 3130.884,
3556.479, 4207.669, 4448.676, 4845.861, 5232.63, 5092.902,
5520.422, 5692.841, 6540.457, 6778.42, 7324.679, 7498.488,
7824.51, 7888.444, 8461.72)), row.names = c(NA, -17L), class = c("tbl_df",
"tbl", "data.frame"))
set.seed(1234) ## to make the results reproducible
## Custom split: the test (assessment) set holds only the most recent
## observation (the last row of df_ini); every earlier row is training data.
## see https://github.com/tidymodels/rsample/issues/158
## seq_len() is used instead of seq() so that a one-row data set yields an
## empty analysis index rather than the descending sequence c(1, 0).
indices <- list(
  analysis = seq_len(nrow(df_ini) - 1L),
  assessment = nrow(df_ini)
)
df_split <- make_splits(indices, df_ini)
## df_split <- initial_split(df_ini) ## with the default splitting,
## ## the code works
df_train <- training(df_split)
df_test <- testing(df_split)
## 3-fold cross-validation of the training rows, used for tuning below.
folded_data <- vfold_cv(df_train, v = 3)
## Preprocessing: `berd` is the outcome; `year` gets the "ID" role so it is
## not treated as a predictor; zero-variance predictors are removed and the
## remaining non-nominal predictors are normalized.
glmnet_recipe <- recipe(formula = berd ~ ., data = df_train)
glmnet_recipe <- update_role(glmnet_recipe, year, new_role = "ID")
glmnet_recipe <- step_zv(glmnet_recipe, all_predictors())
glmnet_recipe <- step_normalize(
  glmnet_recipe,
  all_predictors(),
  -all_nominal()
)

## Model: glmnet linear regression with penalty and mixture left to tune.
glmnet_spec <- linear_reg(penalty = tune(), mixture = tune())
glmnet_spec <- set_mode(glmnet_spec, "regression")
glmnet_spec <- set_engine(glmnet_spec, "glmnet")

## Bundle the recipe and the model specification.
glmnet_workflow <- workflow()
glmnet_workflow <- add_recipe(glmnet_workflow, glmnet_recipe)
glmnet_workflow <- add_model(glmnet_workflow, glmnet_spec)

## Candidate grid: 20 log-spaced penalties crossed with 6 mixture values.
glmnet_grid <- tidyr::crossing(
  penalty = 10^seq(-6, -1, length.out = 20),
  mixture = c(0.05, 0.2, 0.4, 0.6, 0.8, 1)
)

## Evaluate every grid point on the folds, keeping out-of-fold predictions.
glmnet_tune <- tune_grid(
  glmnet_workflow,
  resamples = folded_data,
  grid = glmnet_grid,
  control = control_grid(save_pred = TRUE)
)
## Resampled performance for every penalty/mixture candidate.
glmnet_tune |>
  collect_metrics() |>
  print()
#> # A tibble: 240 × 8
#> penalty mixture .metric .estimator mean n std_err .config
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.000001 0.05 rmse standard 375. 3 48.9 Preprocessor1_Mo…
#> 2 0.000001 0.05 rsq standard 0.929 3 0.0420 Preprocessor1_Mo…
#> 3 0.00000183 0.05 rmse standard 375. 3 48.9 Preprocessor1_Mo…
#> 4 0.00000183 0.05 rsq standard 0.929 3 0.0420 Preprocessor1_Mo…
#> 5 0.00000336 0.05 rmse standard 375. 3 48.9 Preprocessor1_Mo…
#> 6 0.00000336 0.05 rsq standard 0.929 3 0.0420 Preprocessor1_Mo…
#> 7 0.00000616 0.05 rmse standard 375. 3 48.9 Preprocessor1_Mo…
#> 8 0.00000616 0.05 rsq standard 0.929 3 0.0420 Preprocessor1_Mo…
#> 9 0.0000113 0.05 rmse standard 375. 3 48.9 Preprocessor1_Mo…
#> 10 0.0000113 0.05 rsq standard 0.929 3 0.0420 Preprocessor1_Mo…
#> # … with 230 more rows
## Best candidates ranked by RMSE. `metric` is passed by name because newer
## tune releases no longer accept it as a positional argument.
print(show_best(glmnet_tune, metric = "rmse"))
#> # A tibble: 5 × 8
#> penalty mixture .metric .estimator mean n std_err .config
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.000001 0.05 rmse standard 375. 3 48.9 Preprocessor1_Model…
#> 2 0.00000183 0.05 rmse standard 375. 3 48.9 Preprocessor1_Model…
#> 3 0.00000336 0.05 rmse standard 375. 3 48.9 Preprocessor1_Model…
#> 4 0.00000616 0.05 rmse standard 375. 3 48.9 Preprocessor1_Model…
#> 5 0.0000113 0.05 rmse standard 375. 3 48.9 Preprocessor1_Model…
## Pick the penalty/mixture pair with the lowest resampled RMSE and plug it
## into the workflow. `metric` is named so the call keeps working on tune
## releases that no longer accept it positionally.
best_net <- select_best(glmnet_tune, metric = "rmse")
final_net <- finalize_workflow(glmnet_workflow, best_net)
## Refit on the 16 training rows and evaluate on the single held-out row.
## NOTE: with only one assessment row, R-squared (a squared correlation)
## cannot be computed, which is what produces the warning reported below.
## Passing e.g. `metrics = metric_set(rmse, mae)` to last_fit() restricts the
## evaluation to metrics that are defined for a single observation. The call
## is left as posted so that the recorded reprex output still matches.
final_res_net <- last_fit(final_net, df_split)
#> ! train/test split: internal: A correlation computation is required, but the inputs are size zero or o...
print(final_res_net)
#> # Resampling results
#> # Manual resampling
#> # A tibble: 1 × 6
#> splits id .metrics .notes .predictions .workflow
#> <list> <chr> <list> <list> <list> <list>
#> 1 <split [16/1]> train/test split <tibble> <tibble> <tibble [1 × 4]> <workflow>
#>
#> There were issues with some computations:
#>
#> - Warning(s) x1: A correlation computation is required, but the inputs are size ze...
#>
#> Run `show_notes(.Last.tune.result)` for more information.
## Prediction for the held-out (most recent) observation.
final_fit <- collect_predictions(final_res_net)
## Print the full text of the notes attached to the last tune result.
show_notes(.Last.tune.result)
#> unique notes:
#> ────────────────────────────────────────────────────────────────────────────────
#> A correlation computation is required, but the inputs are size zero or one and the standard deviation cannot be computed. `NA` will be returned.
## Record the R version and package versions used for this reprex.
utils::sessionInfo()
#> R version 4.2.1 (2022-06-23)
#> Platform: x86_64-pc-linux-gnu (64-bit)
#> Running under: Debian GNU/Linux 11 (bullseye)
#>
#> Matrix products: default
#> BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
#> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0
#>
#> locale:
#> [1] LC_CTYPE=en_GB.UTF-8 LC_NUMERIC=C
#> [3] LC_TIME=en_GB.UTF-8 LC_COLLATE=en_GB.UTF-8
#> [5] LC_MONETARY=en_GB.UTF-8 LC_MESSAGES=en_GB.UTF-8
#> [7] LC_PAPER=en_GB.UTF-8 LC_NAME=C
#> [9] LC_ADDRESS=C LC_TELEPHONE=C
#> [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] glmnet_4.1-4 Matrix_1.4-1 yardstick_1.0.0 workflowsets_1.0.0
#> [5] workflows_1.0.0 tune_1.0.0 tidyr_1.2.0 tibble_3.1.7
#> [9] rsample_1.0.0 recipes_1.0.1 purrr_0.3.4 parsnip_1.0.0
#> [13] modeldata_1.0.0 infer_1.0.2 ggplot2_3.3.6 dplyr_1.0.9
#> [17] dials_1.0.0 scales_1.2.0 broom_1.0.0 tidymodels_1.0.0
#>
#> loaded via a namespace (and not attached):
#> [1] splines_4.2.1 foreach_1.5.2 prodlim_2019.11.13 assertthat_0.2.1
#> [5] conflicted_1.1.0 highr_0.9 GPfit_1.0-8 yaml_2.3.5
#> [9] globals_0.15.1 ipred_0.9-13 pillar_1.7.0 backports_1.4.1
#> [13] lattice_0.20-45 glue_1.6.2 digest_0.6.29 hardhat_1.2.0
#> [17] colorspace_2.0-3 htmltools_0.5.2 timeDate_3043.102 pkgconfig_2.0.3
#> [21] lhs_1.1.5 DiceDesign_1.9 listenv_0.8.0 gower_1.0.0
#> [25] lava_1.6.10 generics_0.1.2 ellipsis_0.3.2 cachem_1.0.6
#> [29] withr_2.5.0 furrr_0.3.0 nnet_7.3-17 cli_3.3.0
#> [33] survival_3.3-1 magrittr_2.0.3 crayon_1.5.1 memoise_2.0.1
#> [37] evaluate_0.15 fs_1.5.2 future_1.26.1 fansi_1.0.3
#> [41] parallelly_1.32.0 MASS_7.3-57 class_7.3-20 tools_4.2.1
#> [45] lifecycle_1.0.1 stringr_1.4.0 munsell_0.5.0 reprex_2.0.1
#> [49] compiler_4.2.1 rlang_1.0.4 grid_4.2.1 iterators_1.0.14
#> [53] rmarkdown_2.14 gtable_0.3.0 codetools_0.2-18 DBI_1.1.3
#> [57] R6_2.5.1 lubridate_1.8.0 knitr_1.39 fastmap_1.1.0
#> [61] future.apply_1.9.0 utf8_1.2.2 shape_1.4.6 stringi_1.7.6
#> [65] parallel_4.2.1 Rcpp_1.0.8.3 vctrs_0.4.1 rpart_4.1.16
#> [69] tidyselect_1.1.2 xfun_0.31
Created on 2022-08-06 by the reprex package (v2.0.1)
Metadata
Metadata
Assignees
Labels
No labels