|
| 1 | +--- |
| 2 | +title: Tracking Extra Quantities |
| 3 | +engine: julia |
| 4 | +aliases: |
| 5 | + - ../../tutorials/usage-generated-quantities/index.html |
| 6 | + - ../generated-quantities/index.html |
| 7 | +--- |
| 8 | + |
| 9 | +```{julia} |
| 10 | +#| echo: false |
| 11 | +#| output: false |
| 12 | +using Pkg; |
| 13 | +Pkg.instantiate(); |
| 14 | +``` |
| 15 | + |
| 16 | +Often, the most natural parameterization for a model is not the most computationally feasible. |
| 17 | +Consider the following (efficiently reparametrized) implementation of Neal's funnel [(Neal, 2003)](https://arxiv.org/abs/physics/0009028): |
| 18 | + |
| 19 | +```{julia} |
| 20 | +using Turing |
| 21 | +
|
| 22 | +@model function Neal() |
| 23 | + # Raw draws |
| 24 | + y_raw ~ Normal(0, 1) |
| 25 | + x_raw ~ arraydist([Normal(0, 1) for i in 1:9]) |
| 26 | +
|
| 27 | + # Transform: |
| 28 | + y = 3 * y_raw |
| 29 | + x = exp.(y ./ 2) .* x_raw |
| 30 | + return nothing |
| 31 | +end |
| 32 | +``` |
| 33 | + |
| 34 | +In this case, the random variables exposed in the chain (`x_raw`, `y_raw`) are not in a helpful form — what we're after are the deterministically transformed variables `x` and `y`. |
| 35 | + |
| 36 | +More generally, there are often quantities in our models that we might be interested in viewing, but which are not explicitly present in our chain. |
| 37 | + |
| 38 | +There are two ways of tracking such extra quantities. |
| 39 | + |
| 40 | +## Using `:=` (during inference) |
| 41 | + |
| 42 | +The first way is to use the `:=` operator, which behaves exactly like `=` except that the values of the variables on its left-hand side are automatically added to the chain returned by the sampler. |
| 43 | +For example: |
| 44 | + |
| 45 | +```{julia} |
| 46 | +@model function Neal_coloneq() |
| 47 | + # Raw draws |
| 48 | + y_raw ~ Normal(0, 1) |
| 49 | + x_raw ~ arraydist([Normal(0, 1) for i in 1:9]) |
| 50 | +
|
| 51 | + # Transform: |
| 52 | + y := 3 * y_raw |
| 53 | + x := exp.(y ./ 2) .* x_raw |
| 54 | +end |
| 55 | +
|
| 56 | +sample(Neal_coloneq(), NUTS(), 1000; progress=false) |
| 57 | +``` |
| 58 | + |
| 59 | +## Using `returned` (post-inference) |
| 60 | + |
| 61 | +Alternatively, one can specify the extra quantities as part of the model function's return statement: |
| 62 | + |
| 63 | +```{julia} |
| 64 | +@model function Neal_return() |
| 65 | + # Raw draws |
| 66 | + y_raw ~ Normal(0, 1) |
| 67 | + x_raw ~ arraydist([Normal(0, 1) for i in 1:9]) |
| 68 | +
|
| 69 | + # Transform and return as a NamedTuple |
| 70 | + y = 3 * y_raw |
| 71 | + x = exp.(y ./ 2) .* x_raw |
| 72 | + return [x; y] |
| 73 | +end |
| 74 | +
|
| 75 | +chain = sample(Neal_return(), NUTS(), 1000; progress=false) |
| 76 | +``` |
| 77 | + |
| 78 | +This chain does not contain `x` and `y`, but we can extract the values using the `returned` function. |
| 79 | +Calling this function outputs an array of values specified in the return statement of the model. |
| 80 | + |
| 81 | +```{julia} |
| 82 | +returned(Neal_return(), chain) |
| 83 | +``` |
| 84 | + |
| 85 | +Each element of this corresponds to an array with the values of `x1, x2, ..., x9, y` for each posterior sample. |
| 86 | + |
| 87 | +In this case, it might be useful to reorganize our output into a matrix for plotting: |
| 88 | + |
| 89 | +```{julia} |
| 90 | +reparam_chain = reduce(hcat, returned(Neal_return(), chain))' |
| 91 | +``` |
| 92 | + |
| 93 | +from which we can recover a vector of our samples: |
| 94 | + |
| 95 | +```{julia} |
| 96 | +x1_samples = reparam_chain[:, 1] |
| 97 | +y_samples = reparam_chain[:, 10] |
| 98 | +``` |
0 commit comments