|
4 | 4 | ## Interface
|
5 | 5 |
|
6 | 6 | - `Base.ndims(<:AbstractTransform)`: N dims of modes
|
7 |
| - - `transform(<:AbstractTransform, x::AbstractArray)`: Apply the transform to x |
8 | 7 | - `truncate_modes(<:AbstractTransform, x_transformed::AbstractArray)`: Truncate modes
|
9 | 8 | that contribute to the noise
|
10 |
| - - `inverse(<:AbstractTransform, x_transformed::AbstractArray)`: Apply the inverse |
| 9 | +
|
| 10 | +### Transform Interface |
| 11 | +
|
| 12 | + - `plan_transform(<:AbstractTransform, x::AbstractArray, prev_plan)`: Construct a plan to |
| 13 | + apply the transform to x. Might reuse the previous plan if possible |
| 14 | + - `transform(<:AbstractTransform, x::AbstractArray, plan)`: Apply the transform to x using |
| 15 | + the plan |
| 16 | +
|
| 17 | +### Inverse Transform Interface |
| 18 | +
|
| 19 | + - `plan_inverse(<:AbstractTransform, x_transformed::AbstractArray, prev_plan, M)`: |
| 20 | + Construct a plan to apply the inverse transform to `x_transformed`. Might reuse the |
| 21 | + previous plan if possible |
| 22 | + - `inverse(<:AbstractTransform, x_transformed::AbstractArray, plan, M)`: Apply the inverse |
11 | 23 | transform to `x_transformed`
|
12 | 24 | """
|
13 | 25 | abstract type AbstractTransform{T} end
|
|
22 | 34 |
|
23 | 35 | Base.ndims(T::FourierTransform) = length(T.modes)
|
24 | 36 |
|
25 |
| -transform(ft::FourierTransform, x::AbstractArray) = rfft(x, 1:ndims(ft)) |
| 37 | +function plan_transform(ft::FourierTransform, x::AbstractArray, ::Nothing) |
| 38 | + return plan_rfft(x, 1:ndims(ft)) |
| 39 | +end |
| 40 | + |
| 41 | +function plan_transform(ft::FourierTransform, x::AbstractArray, prev_plan) |
| 42 | + size(prev_plan) == size(x) && eltype(prev_plan) == eltype(x) && return prev_plan |
| 43 | + return plan_transform(ft, x, nothing) |
| 44 | +end |
| 45 | + |
| 46 | +@non_differentiable plan_transform(::Any...) |
| 47 | + |
| 48 | +transform(::FourierTransform, x::AbstractArray, plan) = plan * x |
26 | 49 |
|
27 | 50 | function low_pass(ft::FourierTransform, x_fft::AbstractArray)
|
28 | 51 | return view(x_fft, map(d -> 1:d, ft.modes)..., :, :)
|
29 | 52 | end
|
30 | 53 |
|
31 | 54 | truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft)
|
32 | 55 |
|
33 |
| -function inverse( |
34 |
| - ft::FourierTransform, x_fft::AbstractArray{T, N}, M::NTuple{N, Int64}) where {T, N} |
35 |
| - return real(irfft(x_fft, first(M), 1:ndims(ft))) |
| 56 | +function plan_inverse(ft::FourierTransform, x_transformed::AbstractArray{T, N}, |
| 57 | + ::Nothing, M::NTuple{N, Int64}) where {T, N} |
| 58 | + return plan_irfft(x_transformed, first(M), 1:ndims(ft)) |
| 59 | +end |
| 60 | + |
| 61 | +function plan_inverse(ft::FourierTransform, x_transformed::AbstractArray{T, N}, |
| 62 | + prev_plan, M::NTuple{N, Int64}) where {T, N} |
| 63 | + size(prev_plan) == size(x_transformed) && eltype(prev_plan) == eltype(x_transformed) && |
| 64 | + return prev_plan |
| 65 | + return plan_inverse(ft, x_transformed, nothing, M) |
| 66 | +end |
| 67 | + |
| 68 | +@non_differentiable plan_inverse(::Any...) |
| 69 | + |
| 70 | +function inverse(::FourierTransform, x_transformed::AbstractArray{T, N}, plan, |
| 71 | + ::NTuple{N, Int64}) where {T, N} |
| 72 | + return real(plan * x_transformed) |
36 | 73 | end
|
0 commit comments