Skip to content

Commit da90555

Browse files
committed
feat: add step_/layer_ epi_YeoJohnson
1 parent 5721d41 commit da90555

File tree

8 files changed

+981
-4
lines changed

8 files changed

+981
-4
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Imports:
4242
recipes (>= 1.0.4),
4343
rlang (>= 1.1.0),
4444
stats,
45+
stringr,
4546
tibble,
4647
tidyr,
4748
tidyselect,

NAMESPACE

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ S3method(bake,check_enough_data)
1919
S3method(bake,epi_recipe)
2020
S3method(bake,step_adjust_latency)
2121
S3method(bake,step_climate)
22+
S3method(bake,step_epi_YeoJohnson)
2223
S3method(bake,step_epi_ahead)
2324
S3method(bake,step_epi_lag)
2425
S3method(bake,step_epi_slide)
@@ -53,6 +54,7 @@ S3method(prep,check_enough_data)
5354
S3method(prep,epi_recipe)
5455
S3method(prep,step_adjust_latency)
5556
S3method(prep,step_climate)
57+
S3method(prep,step_epi_YeoJohnson)
5658
S3method(prep,step_epi_ahead)
5759
S3method(prep,step_epi_lag)
5860
S3method(prep,step_epi_slide)
@@ -74,6 +76,7 @@ S3method(print,flatline)
7476
S3method(print,frosting)
7577
S3method(print,layer_add_forecast_date)
7678
S3method(print,layer_add_target_date)
79+
S3method(print,layer_epi_YeoJohnson)
7780
S3method(print,layer_naomit)
7881
S3method(print,layer_point_from_distn)
7982
S3method(print,layer_population_scaling)
@@ -84,6 +87,7 @@ S3method(print,layer_threshold)
8487
S3method(print,layer_unnest)
8588
S3method(print,step_adjust_latency)
8689
S3method(print,step_climate)
90+
S3method(print,step_epi_YeoJohnson)
8791
S3method(print,step_epi_ahead)
8892
S3method(print,step_epi_lag)
8993
S3method(print,step_epi_slide)
@@ -99,6 +103,7 @@ S3method(run_mold,default_epi_recipe_blueprint)
99103
S3method(slather,layer_add_forecast_date)
100104
S3method(slather,layer_add_target_date)
101105
S3method(slather,layer_cdc_flatline_quantiles)
106+
S3method(slather,layer_epi_YeoJohnson)
102107
S3method(slather,layer_naomit)
103108
S3method(slather,layer_point_from_distn)
104109
S3method(slather,layer_population_scaling)
@@ -112,6 +117,7 @@ S3method(snap,quantile_pred)
112117
S3method(tidy,check_enough_data)
113118
S3method(tidy,frosting)
114119
S3method(tidy,layer)
120+
S3method(tidy,step_epi_YeoJohnson)
115121
S3method(update,layer)
116122
S3method(vec_arith,quantile_pred)
117123
S3method(vec_arith.numeric,quantile_pred)
@@ -174,6 +180,7 @@ export(layer)
174180
export(layer_add_forecast_date)
175181
export(layer_add_target_date)
176182
export(layer_cdc_flatline_quantiles)
183+
export(layer_epi_YeoJohnson)
177184
export(layer_naomit)
178185
export(layer_point_from_distn)
179186
export(layer_population_scaling)
@@ -205,6 +212,7 @@ export(smooth_quantile_reg)
205212
export(snap)
206213
export(step_adjust_latency)
207214
export(step_climate)
215+
export(step_epi_YeoJohnson)
208216
export(step_epi_ahead)
209217
export(step_epi_lag)
210218
export(step_epi_naomit)

R/layer_yeo_johnson.R

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
#' Unormalizing transformation
2+
#'
3+
#' Will undo a step_epi_YeoJohnson transformation.
4+
#'
5+
#' @param frosting a `frosting` postprocessor. The layer will be added to the
6+
#' sequence of operations for this frosting.
7+
#' @param lambdas Internal. A data frame of lambda values to be used for
8+
#' inverting the transformation.
9+
#' @param ... One or more selector functions to scale variables
10+
#' for this step. See [recipes::selections()] for more details.
11+
#' @param by A (possibly named) character vector of variables to join by.
12+
#' @param id a random id string
13+
#'
14+
#' @return an updated `frosting` postprocessor
15+
#' @export
16+
#' @examples
17+
#' library(dplyr)
18+
#' jhu <- epidatasets::cases_deaths_subset %>%
19+
#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
20+
#' select(geo_value, time_value, cases)
21+
#'
22+
#' # Create a recipe with a Yeo-Johnson transformation.
23+
#' r <- epi_recipe(jhu) %>%
24+
#' step_epi_YeoJohnson(cases) %>%
25+
#' step_epi_lag(cases, lag = 0) %>%
26+
#' step_epi_ahead(cases, ahead = 0, role = "outcome") %>%
27+
#' step_epi_naomit()
28+
#'
29+
#' # Create a frosting layer that will undo the Yeo-Johnson transformation.
30+
#' f <- frosting() %>%
31+
#' layer_predict() %>%
32+
#' layer_epi_YeoJohnson(.pred)
33+
#'
34+
#' # Create a workflow and fit it.
35+
#' wf <- epi_workflow(r, linear_reg()) %>%
36+
#' fit(jhu) %>%
37+
#' add_frosting(f)
38+
#'
39+
#' # Forecast the workflow, which should reverse the Yeo-Johnson transformation.
40+
#' forecast(wf)
41+
#' # Compare to the original data.
42+
#' jhu %>% filter(time_value == "2021-12-31")
43+
#' forecast(wf)
44+
layer_epi_YeoJohnson <- function(frosting, ..., lambdas = NULL, by = NULL, id = rand_id("epi_YeoJohnson")) {
45+
checkmate::assert_tibble(lambdas, min.rows = 1, null.ok = TRUE)
46+
47+
add_layer(
48+
frosting,
49+
layer_epi_YeoJohnson_new(
50+
lambdas = lambdas,
51+
by = by,
52+
terms = dplyr::enquos(...),
53+
id = id
54+
)
55+
)
56+
}
57+
58+
layer_epi_YeoJohnson_new <- function(lambdas, by, terms, id) {
59+
layer("epi_YeoJohnson", lambdas = lambdas, by = by, terms = terms, id = id)
60+
}
61+
62+
#' @export
63+
#' @importFrom workflows extract_preprocessor
64+
slather.layer_epi_YeoJohnson <- function(object, components, workflow, new_data, ...) {
65+
rlang::check_dots_empty()
66+
67+
# Get the lambdas from the layer or from the workflow.
68+
lambdas <- object$lambdas %||% get_lambdas_in_layer(workflow)
69+
70+
# If the by is not specified, try to infer it from the lambdas.
71+
if (is.null(object$by)) {
72+
# Assume `layer_predict` has calculated the prediction keys and other
73+
# layers don't change the prediction key colnames:
74+
prediction_key_colnames <- names(components$keys)
75+
lhs_potential_keys <- prediction_key_colnames
76+
rhs_potential_keys <- colnames(select(lambdas, -starts_with("lambda_")))
77+
object$by <- intersect(lhs_potential_keys, rhs_potential_keys)
78+
suggested_min_keys <- setdiff(lhs_potential_keys, "time_value")
79+
if (!all(suggested_min_keys %in% object$by)) {
80+
cli_warn(
81+
c(
82+
"{setdiff(suggested_min_keys, object$by)} {?was an/were} epikey column{?s} in the predictions,
83+
but {?wasn't/weren't} found in the population `df`.",
84+
"i" = "Defaulting to join by {object$by}",
85+
">" = "Double-check whether column names on the population `df` match those expected in your predictions",
86+
">" = "Consider using population data with breakdowns by {suggested_min_keys}",
87+
">" = "Manually specify `by =` to silence"
88+
),
89+
class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys"
90+
)
91+
}
92+
}
93+
94+
# Establish the join columns.
95+
object$by <- object$by %||%
96+
intersect(
97+
epi_keys_only(components$predictions),
98+
colnames(select(lambdas, -starts_with(".lambda_")))
99+
)
100+
joinby <- list(x = names(object$by) %||% object$by, y = object$by)
101+
hardhat::validate_column_names(components$predictions, joinby$x)
102+
hardhat::validate_column_names(lambdas, joinby$y)
103+
104+
# Join the lambdas.
105+
components$predictions <- inner_join(
106+
components$predictions,
107+
lambdas,
108+
by = object$by,
109+
relationship = "many-to-one",
110+
unmatched = c("error", "drop")
111+
)
112+
113+
exprs <- rlang::expr(c(!!!object$terms))
114+
pos <- tidyselect::eval_select(exprs, components$predictions)
115+
col_names <- names(pos)
116+
117+
# The `object$terms` is where the user specifies the columns they want to
118+
# untransform. We need to match the outcomes with their lambda columns in our
119+
# parameter table and then apply the inverse transformation.
120+
if (identical(col_names, ".pred")) {
121+
# In this case, we don't get a hint for the outcome column name, so we need
122+
# to infer it from the mold.
123+
if (length(components$mold$outcomes) > 1) {
124+
cli_abort("Only one outcome is allowed when specifying `.pred`.", call = rlang::caller_env())
125+
}
126+
# `outcomes` is a vector of objects like ahead_1_cases, ahead_7_cases, etc.
127+
# We want to extract the cases part.
128+
outcome_cols <- names(components$mold$outcomes) %>%
129+
stringr::str_match("ahead_\\d+_(.*)") %>%
130+
magrittr::extract(, 2)
131+
132+
components$predictions <- components$predictions %>%
133+
rowwise() %>%
134+
mutate(.pred := yj_inverse(.pred, !!sym(paste0(".lambda_", outcome_cols))))
135+
} else if (identical(col_names, character(0))) {
136+
# Wish I could suggest `all_outcomes()` here, but currently it's the same as
137+
# not specifying any terms. I don't want to spend time with dealing with
138+
# this case until someone asks for it.
139+
cli::cli_abort(
140+
"Not specifying columns to layer Yeo-Johnson is not implemented.
141+
If you had a single outcome, you can use `.pred` as a column name.
142+
If you had multiple outcomes, you'll need to specify them like
143+
`.pred_ahead_1_<outcome_col>`, `.pred_ahead_7_<outcome_col>`, etc.
144+
",
145+
call = rlang::caller_env()
146+
)
147+
} else {
148+
# In this case, we assume that the user has specified the columns they want
149+
# transformed here. We then need to determine the lambda columns for each of
150+
# these columns. That is, we need to convert a vector of column names like
151+
# c(".pred_ahead_1_case_rate", ".pred_ahead_7_case_rate") to
152+
# c("lambda_ahead_1_case_rate", "lambda_ahead_7_case_rate").
153+
original_outcome_cols <- stringr::str_match(col_names, ".pred_ahead_\\d+_(.*)")[, 2]
154+
outcomes_wout_ahead <- stringr::str_match(names(components$mold$outcomes), "ahead_\\d+_(.*)")[, 2]
155+
if (any(original_outcome_cols %nin% outcomes_wout_ahead)) {
156+
cli_abort(
157+
"All columns specified in `...` must be outcome columns.
158+
They must be of the form `.pred_ahead_1_<outcome_col>`, `.pred_ahead_7_<outcome_col>`, etc.
159+
",
160+
call = rlang::caller_env()
161+
)
162+
}
163+
164+
for (i in seq_along(col_names)) {
165+
col <- col_names[i]
166+
lambda_col <- paste0(".lambda_", original_outcome_cols[i])
167+
components$predictions <- components$predictions %>%
168+
rowwise() %>%
169+
mutate(!!sym(col) := yj_inverse(!!sym(col), !!sym(lambda_col)))
170+
}
171+
}
172+
173+
# Remove the lambda columns.
174+
components$predictions <- components$predictions %>%
175+
select(-any_of(starts_with(".lambda_"))) %>%
176+
ungroup()
177+
components
178+
}
179+
180+
#' @export
181+
print.layer_epi_YeoJohnson <- function(x, width = max(20, options()$width - 30), ...) {
182+
title <- "Yeo-Johnson transformation (see `lambdas` object for values) on "
183+
print_layer(x$terms, title = title, width = width)
184+
}
185+
186+
# Inverse Yeo-Johnson transformation
187+
#
188+
# Inverse of `yj_transform` in step_yeo_johnson.R. Note that this function is
189+
# vectorized in x, but not in lambda.
190+
yj_inverse <- function(x, lambda, eps = 0.001) {
191+
if (is.na(lambda)) {
192+
return(x)
193+
}
194+
if (!inherits(x, "tbl_df") || is.data.frame(x)) {
195+
x <- unlist(x, use.names = FALSE)
196+
} else {
197+
if (!is.vector(x)) {
198+
x <- as.vector(x)
199+
}
200+
}
201+
202+
dat_neg <- x < 0
203+
ind_neg <- list(is = which(dat_neg), not = which(!dat_neg))
204+
not_neg <- ind_neg[["not"]]
205+
is_neg <- ind_neg[["is"]]
206+
207+
nn_inv_trans <- function(x, lambda) {
208+
if (abs(lambda) < eps) {
209+
# log(x + 1)
210+
exp(x) - 1
211+
} else {
212+
# ((x + 1)^lambda - 1) / lambda
213+
(lambda * x + 1)^(1 / lambda) - 1
214+
}
215+
}
216+
217+
ng_inv_trans <- function(x, lambda) {
218+
if (abs(lambda - 2) < eps) {
219+
# -log(-x + 1)
220+
-(exp(-x) - 1)
221+
} else {
222+
# -((-x + 1)^(2 - lambda) - 1) / (2 - lambda)
223+
-(((lambda - 2) * x + 1)^(1 / (2 - lambda)) - 1)
224+
}
225+
}
226+
227+
if (length(not_neg) > 0) {
228+
x[not_neg] <- nn_inv_trans(x[not_neg], lambda)
229+
}
230+
231+
if (length(is_neg) > 0) {
232+
x[is_neg] <- ng_inv_trans(x[is_neg], lambda)
233+
}
234+
x
235+
}
236+
237+
get_lambdas_in_layer <- function(workflow) {
238+
this_recipe <- hardhat::extract_recipe(workflow)
239+
if (!(this_recipe %>% recipes::detect_step("epi_YeoJohnson"))) {
240+
cli_abort("`layer_epi_YeoJohnson` requires `step_epi_YeoJohnson` in the recipe.", call = rlang::caller_env())
241+
}
242+
for (step in this_recipe$steps) {
243+
if (inherits(step, "step_epi_YeoJohnson")) {
244+
lambdas <- step$lambdas
245+
break
246+
}
247+
}
248+
lambdas
249+
}

0 commit comments

Comments
 (0)