Skip to content

Commit 9de027b

Browse files
committed
perf: use the permuted formulation
1 parent 188707c commit 9de027b

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

bench/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
34
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
45
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
56
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

bench/lux.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using ThreadPinning
22
pinthreads(:cores)
33
threadinfo()
44

5-
using BenchmarkTools, NeuralOperators, Random, Optimisers, Zygote
5+
using BenchmarkTools, NeuralOperators, Random, Optimisers, Zygote, Lux
66

77
rng = Xoshiro(1234)
88

@@ -11,9 +11,9 @@ train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...)
1111
function train!(loss, backend, model, ps, st, data; epochs=10)
1212
l1 = loss(model, ps, st, first(data))
1313

14-
tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.01f0))
14+
tstate = Training.TrainState(model, ps, st, Adam(0.01f0))
1515
for _ in 1:epochs, (x, y) in data
16-
_, _, _, tstate = Lux.Experimental.single_train_step!(backend, loss, (x, y), tstate)
16+
_, _, _, tstate = Training.single_train_step!(backend, loss, (x, y), tstate)
1717
end
1818

1919
l2 = loss(model, ps, st, first(data))
@@ -25,14 +25,14 @@ end
2525
n_points = 128
2626
batch_size = 64
2727

28-
x = rand(Float32, 1, n_points, batch_size);
29-
y = rand(Float32, 1, n_points, batch_size);
28+
x = rand(Float32, n_points, 1, batch_size);
29+
y = rand(Float32, n_points, 1, batch_size);
3030
data = [(x, y)];
3131
t_fwd = zeros(5)
3232
t_train = zeros(5)
3333
for i in 1:5
3434
chs = (1, 128, fill(64, i)..., 128, 1)
35-
model = FourierNeuralOperator(gelu; chs=chs, modes=(16,))
35+
model = FourierNeuralOperator(gelu; chs, modes=(16,), permuted=Val(true))
3636
ps, st = Lux.setup(rng, model)
3737
model(x, ps, st) # TTFX
3838

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
55
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
66
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
77
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
8+
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
9+
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
810
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
911
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
1012
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"

0 commit comments

Comments
 (0)