This repository was archived by the owner on Jun 14, 2023. It is now read-only.

Again Error for Unusual Split of Data into Test and Train Sets #260

@larry77

Dear All,
I am going back to tidymodels and I dusted off the example discussed at https://github.com/tidymodels/tidymodels.org/issues/198 .
Please have a look at the reprex.

Now I get a new warning (I am sure I did not see it in the past). It seems to come from the test set consisting of a single data point, but I am not 100% sure of this, nor do I know what to do if that is the case (I really need this unusual data split).
I have two questions:

  1. Can anyone tell me how to fix my code?
  2. Given that my data has a strong time component (df_ini has a "year" column), I think I should use a different resampling technique (see https://www.tmwr.org/resampling.html#rolling ).
    It must be a one-liner, but I am running into issues. Can anyone show me how to implement rolling forecasting origin resampling in my example (once it has been fixed), e.g. non-cumulative, with an analysis set of eight samples (eight years) and an assessment set of two (two years)?

Thanks a lot!
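
For the second question, here is a minimal sketch of what rolling forecasting origin resampling could look like with `rsample::rolling_origin()`. It uses a toy 16-row data frame (`toy_train`, a hypothetical stand-in for the training set, not the real data) and assumes the rows are already sorted by year. Note that the window sizes count rows, not calendar years, so gaps in the `year` column are not taken into account.

```r
library(rsample)

# Hypothetical stand-in for a 16-row training set, sorted by year
toy_train <- data.frame(year = 2002:2017, y = seq_len(16))

# Non-cumulative rolling origin:
# each resample uses 8 consecutive rows for analysis and the next 2 for assessment
rolling_folds <- rolling_origin(
  toy_train,
  initial    = 8,
  assess     = 2,
  cumulative = FALSE
)

# Inspect the first resample
first_split <- rolling_folds$splits[[1]]
analysis(first_split)$year
assessment(first_split)$year
```

In principle an object like `rolling_folds` could be passed as the `resamples` argument of `tune_grid()` in place of the `vfold_cv()` folds, though that has not been checked against the reprex below.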

library(tidymodels)

tidymodels_prefer() 

df_ini <- structure(list(year = c(1998, 2002, 2004, 2005, 2006, 2007, 2008, 
2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018), 
    capital_n1132g_lag_1 = c(3446.5, 4091.1, 3655.1, 3633.3, 
    3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 
    3718.6, 3467.9, 4214.2, 4237.4, 4450.2), capital_n117g_lag_1 = c(4920.9, 
    7810.6, 8560.3, 8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 
    11554.3, 11849.9, 13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 
    17647.1, 18273.8), capital_n11mg_lag_1 = c(16846, 19605, 
    19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6, 
    20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2, 
    29790.1), employment_be_lag_1 = c(2834.42, 2839.72, 2765.53, 
    2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 
    2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9
    ), employment_c_lag_1 = c(2612.76, 2623.69, 2552.89, 2518.57, 
    2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 
    2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48), employment_j_lag_1 = c(292.93, 
    389.2, 389.45, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 
    392.18, 410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4
    ), employment_k_lag_1 = c(505.33, 507.12, 510.25, 504.63, 
    515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 
    524.13, 518.89, 511.57, 505.32, 496.41), employment_mn_lag_1 = c(945.59, 
    1217.96, 1289.55, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 
    1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 
    2021.51, 2109.71), employment_oq_lag_1 = c(3065.87, 3191.75, 
    3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 
    3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 
    4171.72), employment_total_lag_1 = c(14509.58, 15127.99, 
    15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 
    16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 
    17142.13, 17365.32, 17650.21), gdp_b1gq_lag_1 = c(187849.7, 
    220525, 231862.5, 242348.3, 254075, 267824.4, 283978, 293761.9, 
    288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 
    344269.3, 357608, 369341.3), gdp_p3_lag_1 = c(139695.2, 161175.8, 
    169405.6, 176316.4, 185871.1, 194102, 200944.4, 208857.1, 
    213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6, 
    249404.3, 257166.5, 265900.2), gdp_p61_lag_1 = c(50117.6, 
    71948.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1, 
    91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3, 
    129183.6, 131524, 140057.8), gdp_p62_lag_1 = c(19441, 26444.4, 
    28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9, 
    39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5, 
    59584.7), price_index_lag_1 = c(1.2, 2.3, 1.3, 2, 2.1, 1.7, 
    2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 1, 2.2), value_be_lag_1 = c(40533.1, 
    48207.1, 48673.2, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 
    56837.8, 58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 
    71152.6, 72698.8), value_c_lag_1 = c(33441.8, 40446.6, 40467.4, 
    42014.6, 44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 
    51467.7, 53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196
    ), value_j_lag_1 = c(5483.7, 7326.1, 7934.1, 7756.1, 8134.2, 
    8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 
    10361.4, 10695.4, 11455.3, 11720.6), value_k_lag_1 = c(9210.6, 
    9977.3, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 
    12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 
    13236.4, 13744.1), value_mn_lag_1 = c(10444, 14061.4, 15706.6, 
    16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 
    24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6
    ), value_oq_lag_1 = c(29902.7, 34179.2, 36126.8, 37329.6, 
    38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 
    49381.5, 50261.7, 51624.3, 53715, 55926.4, 57637.1), value_total_lag_1 = c(167323.4, 
    197076.7, 207247.6, 216098.3, 225888.1, 239076, 253604.6, 
    262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 
    307037.7, 318952.7, 329396.1), capital_n1132g_lag_2 = c(3599.2, 
    3996.9, 3638.4, 3655.1, 3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 
    3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 3467.9, 4214.2, 4237.4
    ), capital_n117g_lag_2 = c(4636.2, 7008.5, 8369.6, 8560.3, 
    8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 
    13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1), capital_n11mg_lag_2 = c(17181.5, 
    19677.8, 18749.6, 19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 
    23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 
    25019.2, 27608.2), employment_be_lag_2 = c(2870.33, 2840.19, 
    2775.22, 2765.53, 2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 
    2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 
    2750.52), employment_c_lag_2 = c(2626.2, 2621.08, 2562.53, 
    2552.89, 2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 
    2447.65, 2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97
    ), employment_j_lag_2 = c(275.08, 374.56, 400.75, 389.45, 
    387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 
    419.75, 427.59, 438.96, 440.33, 460.84), employment_k_lag_2 = c(500.9, 
    505.13, 502.42, 510.25, 504.63, 515.39, 523.45, 536.6, 550.14, 
    546.68, 539.96, 536.58, 534.98, 524.13, 518.89, 511.57, 505.32
    ), employment_mn_lag_2 = c(904.38, 1143.78, 1248.01, 1289.55, 
    1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 1762.55, 
    1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51), employment_oq_lag_2 = c(3028.85, 
    3162.77, 3241.36, 3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 
    3577.75, 3683.85, 3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 
    4002.74, 4095.59), employment_total_lag_2 = c(14404.29, 15019.87, 
    15113.52, 15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 
    16356.53, 16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 
    17039.6, 17142.13, 17365.32), gdp_b1gq_lag_2 = c(186928.7, 
    213606.4, 226735.3, 231862.5, 242348.3, 254075, 267824.4, 
    283978, 293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 
    323910.2, 333146.1, 344269.3, 357608), gdp_p3_lag_2 = c(140335.8, 
    156117.3, 164107.8, 169405.6, 176316.4, 185871.1, 194102, 
    200944.4, 208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 
    238329.3, 243860.6, 249404.3, 257166.5), gdp_p61_lag_2 = c(44541.4, 
    67701.6, 74691.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 
    113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 
    126109.3, 129183.6, 131524), gdp_p62_lag_2 = c(19504.2, 24888.9, 
    28063.4, 28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 
    38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 
    55885.5), value_be_lag_2 = c(40076.7, 46109.4, 47967.1, 48673.2, 
    50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 
    61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6), value_c_lag_2 = c(32955.4, 
    38908.4, 40192.9, 40467.4, 42014.6, 44229, 47735.5, 51552.4, 
    51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169, 
    57458.7, 60962.8), value_j_lag_2 = c(5576.8, 6313.9, 7737.1, 
    7934.1, 7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 
    9217.1, 9405.1, 9802.1, 10361.4, 10695.4, 11455.3), value_k_lag_2 = c(9191, 
    10458, 10225.2, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 
    13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 
    13482.9, 13236.4), value_mn_lag_2 = c(10092, 12942.5, 15074, 
    15706.6, 16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 
    23255.2, 24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7
    ), value_oq_lag_2 = c(30224.3, 33251.5, 35065.6, 36126.8, 
    37329.6, 38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 
    47980.9, 49381.5, 50261.7, 51624.3, 53715, 55926.4), value_total_lag_2 = c(167141.8, 
    190624.9, 202353.5, 207247.6, 216098.3, 225888.1, 239076, 
    253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 
    297230.1, 307037.7, 318952.7), berd = c(2146.085, 3130.884, 
    3556.479, 4207.669, 4448.676, 4845.861, 5232.63, 5092.902, 
    5520.422, 5692.841, 6540.457, 6778.42, 7324.679, 7498.488, 
    7824.51, 7888.444, 8461.72)), row.names = c(NA, -17L), class = c("tbl_df", 
"tbl", "data.frame"))





set.seed(1234)  ## to make the results reproducible






## I need a particular custom split of my dataset: the test set consists of only the most recent observation, whereas all the rest is the training set

## see https://github.com/tidymodels/rsample/issues/158


indices <-
  list(analysis   = seq(nrow(df_ini)-1), 
       assessment = nrow(df_ini)
       )

df_split <- make_splits(indices, df_ini)


## df_split <- initial_split(df_ini) ## with the default splitting,
## ## the code works

df_train <- training(df_split)
df_test <- testing(df_split)

folded_data <- vfold_cv(df_train, v = 3)



glmnet_recipe <- 
  recipe(formula = berd ~ ., data = df_train) |> 
  update_role(year, new_role = "ID") |> 
  step_zv(all_predictors()) |> 
  step_normalize(all_predictors(), -all_nominal()) 

glmnet_spec <- 
  linear_reg(penalty = tune(), mixture = tune()) |> 
  set_mode("regression") |> 
  set_engine("glmnet") 

glmnet_workflow <- 
  workflow() |> 
  add_recipe(glmnet_recipe) |> 
  add_model(glmnet_spec) 




glmnet_grid <- tidyr::crossing(
  penalty = 10^seq(-6, -1, length.out = 20),
  mixture = c(0.05, 0.2, 0.4, 0.6, 0.8, 1)
)

glmnet_tune <- 
  tune_grid(
    glmnet_workflow,
    resamples = folded_data,
    grid = glmnet_grid,
    control = control_grid(save_pred = TRUE)
  )

print(collect_metrics(glmnet_tune))
#> # A tibble: 240 × 8
#>       penalty mixture .metric .estimator    mean     n std_err .config          
#>         <dbl>   <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>            
#>  1 0.000001      0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  2 0.000001      0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  3 0.00000183    0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  4 0.00000183    0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  5 0.00000336    0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  6 0.00000336    0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  7 0.00000616    0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  8 0.00000616    0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  9 0.0000113     0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#> 10 0.0000113     0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#> # … with 230 more rows

print(show_best(glmnet_tune, "rmse"))
#> # A tibble: 5 × 8
#>      penalty mixture .metric .estimator  mean     n std_err .config             
#>        <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 0.000001      0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 2 0.00000183    0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 3 0.00000336    0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 4 0.00000616    0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 5 0.0000113     0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…

best_net <- select_best(glmnet_tune, "rmse")


final_net <- finalize_workflow(
  glmnet_workflow,
  best_net
)


final_res_net <- last_fit(final_net, df_split)
#> ! train/test split: internal: A correlation computation is required, but the inputs are size zero or o...


print(final_res_net)
#> # Resampling results
#> # Manual resampling 
#> # A tibble: 1 × 6
#>   splits         id               .metrics .notes   .predictions     .workflow 
#>   <list>         <chr>            <list>   <list>   <list>           <list>    
#> 1 <split [16/1]> train/test split <tibble> <tibble> <tibble [1 × 4]> <workflow>
#> 
#> There were issues with some computations:
#> 
#>   - Warning(s) x1: A correlation computation is required, but the inputs are size ze...
#> 
#> Run `show_notes(.Last.tune.result)` for more information.

final_fit <- final_res_net |>
    collect_predictions()



show_notes(.Last.tune.result)
#> unique notes:
#> ────────────────────────────────────────────────────────────────────────────────
#> A correlation computation is required, but the inputs are size zero or one and the standard deviation cannot be computed. `NA` will be returned.


sessionInfo()
#> R version 4.2.1 (2022-06-23)
#> Platform: x86_64-pc-linux-gnu (64-bit)
#> Running under: Debian GNU/Linux 11 (bullseye)
#> 
#> Matrix products: default
#> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
#> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0
#> 
#> locale:
#>  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
#>  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
#>  [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] glmnet_4.1-4       Matrix_1.4-1       yardstick_1.0.0    workflowsets_1.0.0
#>  [5] workflows_1.0.0    tune_1.0.0         tidyr_1.2.0        tibble_3.1.7      
#>  [9] rsample_1.0.0      recipes_1.0.1      purrr_0.3.4        parsnip_1.0.0     
#> [13] modeldata_1.0.0    infer_1.0.2        ggplot2_3.3.6      dplyr_1.0.9       
#> [17] dials_1.0.0        scales_1.2.0       broom_1.0.0        tidymodels_1.0.0  
#> 
#> loaded via a namespace (and not attached):
#>  [1] splines_4.2.1      foreach_1.5.2      prodlim_2019.11.13 assertthat_0.2.1  
#>  [5] conflicted_1.1.0   highr_0.9          GPfit_1.0-8        yaml_2.3.5        
#>  [9] globals_0.15.1     ipred_0.9-13       pillar_1.7.0       backports_1.4.1   
#> [13] lattice_0.20-45    glue_1.6.2         digest_0.6.29      hardhat_1.2.0     
#> [17] colorspace_2.0-3   htmltools_0.5.2    timeDate_3043.102  pkgconfig_2.0.3   
#> [21] lhs_1.1.5          DiceDesign_1.9     listenv_0.8.0      gower_1.0.0       
#> [25] lava_1.6.10        generics_0.1.2     ellipsis_0.3.2     cachem_1.0.6      
#> [29] withr_2.5.0        furrr_0.3.0        nnet_7.3-17        cli_3.3.0         
#> [33] survival_3.3-1     magrittr_2.0.3     crayon_1.5.1       memoise_2.0.1     
#> [37] evaluate_0.15      fs_1.5.2           future_1.26.1      fansi_1.0.3       
#> [41] parallelly_1.32.0  MASS_7.3-57        class_7.3-20       tools_4.2.1       
#> [45] lifecycle_1.0.1    stringr_1.4.0      munsell_0.5.0      reprex_2.0.1      
#> [49] compiler_4.2.1     rlang_1.0.4        grid_4.2.1         iterators_1.0.14  
#> [53] rmarkdown_2.14     gtable_0.3.0       codetools_0.2-18   DBI_1.1.3         
#> [57] R6_2.5.1           lubridate_1.8.0    knitr_1.39         fastmap_1.1.0     
#> [61] future.apply_1.9.0 utf8_1.2.2         shape_1.4.6        stringi_1.7.6     
#> [65] parallel_4.2.1     Rcpp_1.0.8.3       vctrs_0.4.1        rpart_4.1.16      
#> [69] tidyselect_1.1.2   xfun_0.31

Created on 2022-08-06 by the reprex package (v2.0.1)
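
The note in the reprex mentions a correlation computed on inputs of size zero or one, which is consistent with R-squared being evaluated on a test set that holds a single row. The sketch below reproduces that situation in isolation with yardstick. The one-row data frame and its `estimate` value are made up for illustration, and the suggestion to drop `rsq` is an assumption about what is acceptable for this analysis, not a confirmed fix.

```r
library(yardstick)

# One observed value and one hypothetical prediction (the estimate is made up)
one_row <- data.frame(truth = 8461.72, estimate = 8300)

# R-squared needs a correlation, which is undefined for a single point:
# yardstick warns and returns NA
rsq_res <- suppressWarnings(rsq(one_row, truth, estimate))
rsq_res$.estimate

# RMSE is still well defined for a single point
rmse_res <- rmse(one_row, truth, estimate)
rmse_res$.estimate

# A metric set without rsq avoids the warning altogether
single_point_metrics <- metric_set(rmse, mae)
single_res <- single_point_metrics(one_row, truth = truth, estimate = estimate)
single_res
```

`last_fit()` accepts a `metrics` argument, so something like `last_fit(final_net, df_split, metrics = single_point_metrics)` may be enough to silence the warning. The RMSE and the prediction itself are unaffected either way.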
