Skip to content

Commit dd5b486

Browse files
committed
feat: setup basic flux code
[skip ci] [skip docs]
1 parent b4db705 commit dd5b486

File tree

4 files changed

+117
-40
lines changed

4 files changed

+117
-40
lines changed

bench/comparison.md

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,49 @@
1-
## NeuralOperators.jl (Lux)
1+
# NeuralOperators.jl Benchmarks
22

3-
## FNO
3+
## Fourier Neural Operators
44

5-
| #layers | Forward | Train: 10 epochs |
6-
| --- | --- | --- |
7-
| 1 | 14.173699999999998 ms | 755.1466 ms |
8-
| 2 | 29.118399999999998 ms | 1407.2298 ms |
9-
| 3 | 37.6924 ms | 2367.5004999999996 ms |
10-
| 4 | 41.431400000000004 ms | 3035.1971 ms |
11-
| 5 | 59.305 ms | 3456.1902999999998 ms |
5+
### Lux.jl (Julia)
126

13-
## FNO (python: neuraloperator)
7+
| #layers | Forward | Train: 10 epochs |
8+
|:------- |:--------------------- |:--------------------- |
9+
| 1 | 14.173699999999998 ms | 755.1466 ms |
10+
| 2 | 29.118399999999998 ms | 1407.2298 ms |
11+
| 3 | 37.6924 ms | 2367.5004999999996 ms |
12+
| 4 | 41.431400000000004 ms | 3035.1971 ms |
13+
| 5 | 59.305 ms | 3456.1902999999998 ms |
1414

15-
| #layers | Forward | Train: 10 epochs |
16-
| --- | --- | --- |
17-
| 1 | 5.731542900000932 ms | 17.667421199992532 ms |
18-
| 2 | 7.833489999989979 ms | 25.585920999990776 ms |
19-
| 3 | 10.18306370000937 ms | 33.69801080002799 ms |
20-
| 4 | 12.33892210002523 ms | 41.98180860001594 ms |
21-
| 5 | 14.732645300013246 ms | 50.13744520000182 ms |
15+
### Flux.jl (Julia)
16+
17+
### neuraloperator (Python)
18+
19+
| #layers | Forward | Train: 10 epochs |
20+
|:------- |:---------------------- |:---------------------- |
21+
| 1 | 5.731542900000932 ms | 17.667421199992532 ms |
22+
| 2 | 7.833489999989979 ms | 25.585920999990776 ms |
23+
| 3 | 10.18306370000937 ms | 33.69801080002799 ms |
24+
| 4 | 12.33892210002523 ms | 41.98180860001594 ms |
25+
| 5 | 14.732645300013246 ms | 50.13744520000182 ms |
2226

2327
## DeepONet
2428

25-
| #layers | Forward | Train: 10 epochs |
26-
| --- | --- | --- |
27-
| 1 | 3.3952750000000003 ms | 76.604576 ms |
28-
| 2 | 4.360458 ms | 104.460251 ms |
29-
| 3 | 5.6310780000000005 ms | 149.148633 ms |
30-
| 4 | 7.199777 ms | 178.464657 ms |
31-
| 5 | 7.8226819999999995 ms | 193.760173 ms |
32-
33-
## DeepONet (python: deepxde)
34-
35-
| #layers | Forward | Train: 10 epochs |
36-
| --- | --- | --- |
37-
| 1 | 0.7689221948385239 ms | 25.76469287276268 ms |
38-
| 2 | 0.7733150571584702 ms | 32.17746138572693 ms |
39-
| 3 | 0.8474267274141312 ms | 36.93301998078823 ms |
40-
| 4 | 1.0069304704666138 ms | 45.45578710734844 ms |
41-
| 5 | 1.406572386622429 ms | 59.06449243426323 ms |
29+
### Lux.jl (Julia)
30+
31+
| #layers | Forward | Train: 10 epochs |
32+
|:------- |:--------------------- |:---------------- |
33+
| 1 | 3.3952750000000003 ms | 76.604576 ms |
34+
| 2 | 4.360458 ms | 104.460251 ms |
35+
| 3 | 5.6310780000000005 ms | 149.148633 ms |
36+
| 4 | 7.199777 ms | 178.464657 ms |
37+
| 5 | 7.8226819999999995 ms | 193.760173 ms |
38+
39+
### Flux.jl (Julia)
40+
41+
### deepxde (Python)
42+
43+
| #layers | Forward | Train: 10 epochs |
44+
|:------- |:---------------------- |:--------------------- |
45+
| 1 | 0.7689221948385239 ms | 25.76469287276268 ms |
46+
| 2 | 0.7733150571584702 ms | 32.17746138572693 ms |
47+
| 3 | 0.8474267274141312 ms | 36.93301998078823 ms |
48+
| 4 | 1.0069304704666138 ms | 45.45578710734844 ms |
49+
| 5 | 1.406572386622429 ms | 59.06449243426323 ms |

bench/flux.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
using ThreadPinning
2+
pinthreads(:cores)
3+
threadinfo()
4+
5+
using BenchmarkTools, NeuralOperators, Random, Optimisers, Zygote
6+
7+
# TODO: Add training code
8+
9+
# FNO
10+
n_points = 128
11+
batch_size = 64
12+
13+
x = rand(Float32, 1, n_points, batch_size);
14+
y = rand(Float32, 1, n_points, batch_size);
15+
data = [(x, y)];
16+
t_fwd = zeros(5)
17+
t_train = zeros(5)
18+
19+
for i in 1:5
20+
chs = (1, 128, fill(64, i)..., 128, 1)
21+
model = FourierNeuralOperator(; ch=chs, modes=(16,), σ=gelu)
22+
# model(x) # TTFX
23+
24+
# t_fwd[i] = @belapsed $model($x)
25+
26+
# t_train[i] = @belapsed train!($model, $ps, $st, $data; epochs=10)
27+
end
28+
29+
println("\n## FNO (Flux NeuralOperators.jl)")
30+
print("| #layers | Forward | Train: 10 epochs | \n")
31+
print("| --- | --- | --- | \n")
32+
for i in 1:5
33+
print("| $i | $(t_fwd[i] * 1000) ms | $(t_train[i] * 1000) ms | \n")
34+
end
35+
36+
# DeepONets
37+
eval_points = 128
38+
batch_size = 64
39+
dim_y = 1
40+
m = 32
41+
42+
u = rand(Float32, m, batch_size);
43+
y = rand(Float32, dim_y, eval_points, batch_size);
44+
45+
g = rand(Float32, eval_points, batch_size);
46+
47+
data = [((u, y), g)]
48+
t_fwd = zeros(5)
49+
t_train = zeros(5)
50+
for i in 1:5
51+
ch_branch = (m, fill(64, i)..., 128)
52+
ch_trunk = (dim_y, fill(64, i)..., 128)
53+
model = DeepONet(ch_branch, ch_trunk)
54+
# model(u, y) # TTFX
55+
56+
# t_fwd[i] = @belapsed $model($u, $y)
57+
58+
# t_train[i] = @belapsed train!($model, $ps, $st, $data; epochs=10)
59+
end
60+
61+
println("\n## DeepONet (Flux NeuralOperators.jl)")
62+
print("| #layers | Forward | Train: 10 epochs | \n")
63+
print("| --- | --- | --- | \n")
64+
for i in 1:5
65+
print("| $i | $(t_fwd[i] * 1000) ms | $(t_train[i] * 1000) ms | \n")
66+
end

bench/flux/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[deps]
2+
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
3+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
4+
ThreadPinning = "811555cd-349b-4f26-b7bc-1f208b848042"
5+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

bench/lux.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ for i in 1:5
3434
chs = (1, 128, fill(64, i)..., 128, 1)
3535
model = FourierNeuralOperator(gelu; chs=chs, modes=(16,))
3636
ps, st = Lux.setup(rng, model)
37-
_ = model(x, ps, st) # TTFX
37+
model(x, ps, st) # TTFX
3838

3939
t_fwd[i] = @belapsed $model($x, $ps, $st)
4040

4141
t_train[i] = @belapsed train!($model, $ps, $st, $data; epochs=10)
4242
end
4343

44-
println("\n ### FNO")
44+
println("\n## FNO")
4545
print("| #layers | Forward | Train: 10 epochs | \n")
4646
print("| --- | --- | --- | \n")
4747
for i in 1:5
@@ -67,18 +67,16 @@ for i in 1:5
6767
ch_trunk = (dim_y, fill(64, i)..., 128)
6868
model = DeepONet(; branch=ch_branch, trunk=ch_trunk)
6969
ps, st = Lux.setup(rng, model)
70-
_ = model((u, y), ps, st) # TTFX
70+
model((u, y), ps, st) # TTFX
7171

7272
t_fwd[i] = @belapsed $model(($u, $y), $ps, $st)
7373

7474
t_train[i] = @belapsed train!($model, $ps, $st, $data; epochs=10)
7575
end
7676

77-
println("\n ### DeepONet")
77+
println("\n## DeepONet")
7878
print("| #layers | Forward | Train: 10 epochs | \n")
7979
print("| --- | --- | --- | \n")
8080
for i in 1:5
8181
print("| $i | $(t_fwd[i] * 1000) ms | $(t_train[i] * 1000) ms | \n")
8282
end
83-
84-

0 commit comments

Comments
 (0)