Skip to content

Commit 96d7b5d

Browse files
committed
feat: cache plans for fft
1 parent ca41edd commit 96d7b5d

File tree

3 files changed

+61
-13
lines changed

3 files changed

+61
-13
lines changed

src/NeuralOperators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module NeuralOperators
33
using ArgCheck: @argcheck
44
using ChainRulesCore: @non_differentiable
55
using ConcreteStructs: @concrete
6-
using FFTW: FFTW, irfft, rfft
6+
using FFTW: FFTW, plan_rfft, plan_irfft
77
using Random: Random, AbstractRNG
88
using Static: StaticBool, False, True, known, static
99

src/layers.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,14 @@ function LuxCore.initialparameters(rng::AbstractRNG, layer::OperatorConv)
4848
rng, eltype(layer.tform), out_chs, in_chs, layer.prod_modes))
4949
end
5050

51+
function LuxCore.initialstates(::AbstractRNG, layer::OperatorConv)
52+
fake_x = zeros(Float32, ntuple(Returns(1), ndims(layer.tform))..., 1)
53+
plan_tform = plan_transform(layer.tform, fake_x, nothing)
54+
x = transform(layer.tform, fake_x, plan_tform)
55+
plan_inv_tform = plan_inverse(layer.tform, x, nothing, size(x))
56+
return (; plan_tform, plan_inv_tform)
57+
end
58+
5159
function LuxCore.parameterlength(layer::OperatorConv)
5260
return layer.prod_modes * layer.in_chs * layer.out_chs
5361
end
@@ -59,27 +67,30 @@ function OperatorConv(
5967
end
6068

6169
function (conv::OperatorConv{True})(x::AbstractArray, ps, st)
62-
return operator_conv(x, conv.tform, ps.weight), st
70+
return operator_conv(x, conv.tform, ps.weight, st)
6371
end
6472

6573
function (conv::OperatorConv{False})(x::AbstractArray, ps, st)
6674
N = ndims(conv.tform)
6775
xᵀ = permutedims(x, (ntuple(i -> i + 1, N)..., 1, N + 2))
68-
yᵀ = operator_conv(xᵀ, conv.tform, ps.weight)
76+
yᵀ, stₙ = operator_conv(xᵀ, conv.tform, ps.weight, st)
6977
y = permutedims(yᵀ, (N + 1, 1:N..., N + 2))
70-
return y, st
78+
return y, stₙ
7179
end
7280

73-
function operator_conv(x, tform::AbstractTransform, weights)
74-
x_t = transform(tform, x)
81+
function operator_conv(x, tform::AbstractTransform, weights, st)
82+
plan_tform = plan_transform(tform, x, st.plan_tform)
83+
x_t = transform(tform, x, plan_tform)
84+
7585
x_tr = truncate_modes(tform, x_t)
7686
x_p = apply_pattern(x_tr, weights)
7787

7888
pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)]
7989
x_padded = NNlib.pad_constant(x_p, expand_pad_dims(pad_dims), false;
8090
dims=ntuple(identity, ndims(x_p) - 2))::typeof(x_p)
8191

82-
return inverse(tform, x_padded, size(x))
92+
plan_inv_tform = plan_inverse(tform, x_padded, st.plan_inv_tform, size(x))
93+
return inverse(tform, x_padded, plan_inv_tform, size(x)), (; plan_tform, plan_inv_tform)
8394
end
8495

8596
"""

src/transform.jl

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,22 @@
44
## Interface
55
66
- `Base.ndims(<:AbstractTransform)`: N dims of modes
7-
- `transform(<:AbstractTransform, x::AbstractArray)`: Apply the transform to x
87
- `truncate_modes(<:AbstractTransform, x_transformed::AbstractArray)`: Truncate modes
98
that contribute to the noise
10-
- `inverse(<:AbstractTransform, x_transformed::AbstractArray)`: Apply the inverse
9+
10+
### Transform Interface
11+
12+
- `plan_transform(<:AbstractTransform, x::AbstractArray, prev_plan)`: Construct a plan to
13+
apply the transform to x. Might reuse the previous plan if possible
14+
- `transform(<:AbstractTransform, x::AbstractArray, plan)`: Apply the transform to x using
15+
the plan
16+
17+
### Inverse Transform Interface
18+
19+
- `plan_inverse(<:AbstractTransform, x_transformed::AbstractArray, prev_plan, M)`:
20+
Construct a plan to apply the inverse transform to `x_transformed`. Might reuse the
21+
previous plan if possible
22+
- `inverse(<:AbstractTransform, x_transformed::AbstractArray, plan, M)`: Apply the inverse
1123
transform to `x_transformed`
1224
"""
1325
abstract type AbstractTransform{T} end
@@ -22,15 +34,40 @@ end
2234

2335
Base.ndims(T::FourierTransform) = length(T.modes)
2436

25-
transform(ft::FourierTransform, x::AbstractArray) = rfft(x, 1:ndims(ft))
37+
function plan_transform(ft::FourierTransform, x::AbstractArray, ::Nothing)
38+
return plan_rfft(x, 1:ndims(ft))
39+
end
40+
41+
function plan_transform(ft::FourierTransform, x::AbstractArray, prev_plan)
42+
size(prev_plan) == size(x) && eltype(prev_plan) == eltype(x) && return prev_plan
43+
return plan_transform(ft, x, nothing)
44+
end
45+
46+
@non_differentiable plan_transform(::Any...)
47+
48+
transform(::FourierTransform, x::AbstractArray, plan) = plan * x
2649

2750
function low_pass(ft::FourierTransform, x_fft::AbstractArray)
2851
return view(x_fft, map(d -> 1:d, ft.modes)..., :, :)
2952
end
3053

3154
truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft)
3255

33-
function inverse(
34-
ft::FourierTransform, x_fft::AbstractArray{T, N}, M::NTuple{N, Int64}) where {T, N}
35-
return real(irfft(x_fft, first(M), 1:ndims(ft)))
56+
function plan_inverse(ft::FourierTransform, x_transformed::AbstractArray{T, N},
57+
::Nothing, M::NTuple{N, Int64}) where {T, N}
58+
return plan_irfft(x_transformed, first(M), 1:ndims(ft))
59+
end
60+
61+
function plan_inverse(ft::FourierTransform, x_transformed::AbstractArray{T, N},
62+
prev_plan, M::NTuple{N, Int64}) where {T, N}
63+
size(prev_plan) == size(x_transformed) && eltype(prev_plan) == eltype(x_transformed) &&
64+
return prev_plan
65+
return plan_inverse(ft, x_transformed, nothing, M)
66+
end
67+
68+
@non_differentiable plan_inverse(::Any...)
69+
70+
function inverse(::FourierTransform, x_transformed::AbstractArray{T, N}, plan,
71+
::NTuple{N, Int64}) where {T, N}
72+
return real(plan * x_transformed)
3673
end

0 commit comments

Comments
 (0)