FNO example #362

Open
ma595 wants to merge 30 commits into main from fno_example

@ma595 commented Apr 22, 2025

Member

This PR introduces an FNO (Fourier Neural Operator) example inspired by Hamid's code here. The example trains a model on sampled sine waves, then loads and calls the model from Fortran.

I welcome comments on whether this should be merged into FTorch as is, or whether it needs to be adapted to showcase online training. Perhaps @jatkinson1000 can comment on this.

Ready for review once the following is completed:

  • Test Fortran coupling.
  • Ensure we adopt the FTorch examples' "style" of running pt2ts first.

@ma595 self-assigned this Apr 22, 2025
@ma595 marked this pull request as draft April 22, 2025 17:45
@ma595 marked this pull request as ready for review May 7, 2025 08:21

@joewallwork left a comment

Collaborator

Thanks for this contribution, @ma595! I think it'd be a great addition to our examples. However, adding it as an example will also require setting it up in the CI, so I've provided some info on how to do that.

I can't comment much on your FNO implementation but I have a few suggested edits related to how we've generally got things set up in FTorch.

Comment thread examples/9_FNO/fno1d.py Outdated

return x

# UNUSED

Collaborator

I guess we can drop it?

@ma595 May 9, 2025

Member Author

See comment above. But yes, this needs resolving before merge.

Comment thread examples/9_FNO/fno1d.py Outdated
Comment on lines +194 to +195
# grid = self.get_grid(u.shape, u.device)
# x = torch.cat((u, grid), dim=-1) # Add grid as extra channel

Collaborator

These too?

Member Author

Thanks for pointing this out. I moved the grid's creation and concatenation out of this class because I was experimenting with non-uniform grids, just out of curiosity. That code now lives in generate_parametric_sine_data in fno1d_train.py:

from typing import Optional, Tuple

import numpy as np
import torch


def generate_parametric_sine_data(
    batch_size: int = 32,
    size_x: int = 32,
    random_x: bool = False,
    seed: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate a batch of sine waves with varying amplitude, frequency, and phase.

    The data consists of a grid of points and their corresponding sine values.
    The aim is to train the network on a variety of sine waves with different
    parameters. The x-values can be either random or evenly spaced. Evenly
    spaced x-values are recommended.

    Parameters
    ----------
    batch_size : int
        Number of sine functions to generate.
    size_x : int
        Number of spatial points per sample.
    random_x : bool
        If True, use random (sorted) x-points per sample. Otherwise use linspace.
    seed : int or None
        Random seed for reproducibility.

    Returns
    -------
    dummy : torch.Tensor
        Dummy input of shape (batch_size, size_x, 1)
    grid : torch.Tensor
        x-values of shape (batch_size, size_x, 1)
    target : torch.Tensor
        Target u(x) = A * sin(2π * f * x + φ), shape (batch_size, size_x, 1)
    """
    rng = np.random.default_rng(seed)
    dummy_batch = []
    grid_batch = []
    target_batch = []
    for _ in range(batch_size):
        # Generate grid
        if random_x:
            x = np.sort(rng.uniform(0, 1, size_x))
        else:
            x = np.linspace(0, 1, size_x)
        # Generate sine parameters
        A = rng.uniform(0.5, 1.5)
        f = rng.uniform(1.0, 3.0)
        phi = rng.uniform(0, 2 * np.pi)
        y = A * np.sin(2 * np.pi * f * x + phi)
        x_tensor = torch.tensor(x, dtype=torch.float32).view(1, size_x, 1)
        y_tensor = torch.tensor(y, dtype=torch.float32).view(1, size_x, 1)
        dummy_tensor = torch.zeros_like(x_tensor)
        grid_batch.append(x_tensor)
        target_batch.append(y_tensor)
        dummy_batch.append(dummy_tensor)
    dummy = torch.cat(dummy_batch, dim=0)  # (batch_size, size_x, 1)
    grid = torch.cat(grid_batch, dim=0)
    target = torch.cat(target_batch, dim=0)
    return dummy, grid, target

But it isn't used, and I need to check whether it's actually worth keeping. I doubt that non-uniform grids are that useful in climate applications.
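
For reference, here is a minimal sketch of how the grid could be attached as an extra input channel outside the model class, mirroring the commented-out torch.cat((u, grid), dim=-1) line. It assumes the (batch_size, size_x, 1) shapes from the docstring above. The helper names add_grid_channel and uniform_grid are hypothetical and not part of this PR.

```python
import torch


def uniform_grid(batch_size: int, size_x: int) -> torch.Tensor:
    """Evenly spaced x-values on [0, 1], shape (batch_size, size_x, 1)."""
    x = torch.linspace(0.0, 1.0, size_x)
    return x.view(1, size_x, 1).repeat(batch_size, 1, 1)


def add_grid_channel(u: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Concatenate the grid onto the input as an extra channel.

    Both tensors are assumed to have shape (batch_size, size_x, 1);
    the result has shape (batch_size, size_x, 2).
    """
    if u.shape != grid.shape:
        raise ValueError(f"shape mismatch: {tuple(u.shape)} vs {tuple(grid.shape)}")
    return torch.cat((u, grid), dim=-1)
```

With this arrangement the model would receive two input channels rather than one, so its first layer would need to be sized accordingly. A non-uniform grid could be passed in the same way without touching the model class.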

Comment thread examples/9_FNO/CMakeLists.txt Outdated
Comment thread examples/9_FNO/fno1d_infer_fortran.f90 Outdated
Comment thread examples/9_FNO/fno1d_infer_fortran.f90 Outdated
! Infer
call torch_model_forward(model, in_tensors, out_tensors)

! write (*,*) out_data(:)

Collaborator

Remove this, too.

Comment thread examples/9_FNO/fno1d_infer_fortran.f90 Outdated
Comment thread examples/9_FNO/fno1d_infer_fortran.f90 Outdated
Comment thread examples/9_FNO/fno1d_infer_fortran.f90 Outdated
Comment thread examples/9_FNO/fno1d_train.py Outdated

Collaborator

It's great that you've included the script for training, especially because we don't currently have this for the other examples. I ran the training script locally and it was pretty fast, so I'd be happy for this to be added to the CI. Besides, we'd need to add this in order to run the other ctests you add.

Collaborator

Note that you'll need the patch

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 633cd6d..05c0efd 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -15,4 +15,5 @@ if(CMAKE_BUILD_TESTS)
     add_subdirectory(7_MPI)
   endif()
   add_subdirectory(8_Autograd)
+  add_subdirectory(9_FNO)
 endif()
diff --git a/run_test_suite.sh b/run_test_suite.sh
index 9fbac9e..aa09c7a 100755
--- a/run_test_suite.sh
+++ b/run_test_suite.sh
@@ -83,9 +83,9 @@ fi
 # Run integration tests
 if [ "${RUN_INTEGRATION}" = true ]; then
   if [ -e "${BUILD_DIR}/examples/6_MultiGPU" ]; then
-    EXAMPLES="1_Tensor 2_SimpleNet 3_ResNet 4_MultiIO 6_MultiGPU 7_MPI 8_Autograd"
+    EXAMPLES="1_Tensor 2_SimpleNet 3_ResNet 4_MultiIO 6_MultiGPU 7_MPI 8_Autograd 9_FNO"
   else
-    EXAMPLES="1_Tensor 2_SimpleNet 3_ResNet 4_MultiIO 7_MPI 8_Autograd"
+    EXAMPLES="1_Tensor 2_SimpleNet 3_ResNet 4_MultiIO 7_MPI 8_Autograd 9_FNO"
   fi
   export PIP_REQUIRE_VIRTUALENV=true
   for EXAMPLE in ${EXAMPLES}; do

to do this, although the tests won't pass as-is.

Collaborator

I tried running python3 fno1d_train.py but got

Traceback (most recent call last):
  File "/home/joe/software/ftorch/examples/9_FNO/fno1d_train.py", line 260, in <module>
    main()
  File "/home/joe/software/ftorch/examples/9_FNO/fno1d_train.py", line 235, in main
    validate()
  File "/home/joe/software/ftorch/examples/9_FNO/fno1d_train.py", line 183, in validate
    loaded_model = torch.jit.load("fno1d_sine.pt")
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/joe/.virtualenvs/ftorch/lib/python3.12/site-packages/torch/jit/_serialization.py", line 153, in load
    raise ValueError(f"The provided filename {f} does not exist")
ValueError: The provided filename fno1d_sine.pt does not exist

It looks like the model needs to either be saved to file or passed to validate.
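
A minimal sketch of the first option, saving the TorchScript file before validate tries to load it, might look like the following. TinyModel is a hypothetical stand-in for the FNO model, and save_torchscript and run_roundtrip are illustrative names rather than functions from fno1d_train.py. It scripts the model with torch.jit.script purely for illustration; the PR may export differently (for example via pt2ts).

```python
import os
import tempfile

import torch


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for the FNO model; not the PR's architecture."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def save_torchscript(model: torch.nn.Module, path: str) -> None:
    """Script the model and write it to `path` so it can be loaded later."""
    model.eval()
    torch.jit.script(model).save(path)


def validate(path: str, reference: torch.nn.Module) -> bool:
    """Load the saved TorchScript model and compare it with the in-memory one."""
    loaded = torch.jit.load(path)
    x = torch.linspace(0.0, 1.0, 32).view(1, 32, 1)
    with torch.no_grad():
        return torch.allclose(loaded(x), reference(x))


def run_roundtrip() -> bool:
    """Save first, then validate, using a temporary directory."""
    model = TinyModel()
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "fno1d_sine.pt")
        save_torchscript(model, path)  # must happen before validate()
        return validate(path, model)
```

Passing the in-memory model to validate as well, as done here, also covers the second option and lets the check confirm that the saved file reproduces the trained model's outputs.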
