diff --git a/examples/nsfnet/VP_NSFNet1.py b/examples/nsfnet/VP_NSFNet1.py
index 8e917ab6f2..4c6b0572e0 100644
--- a/examples/nsfnet/VP_NSFNet1.py
+++ b/examples/nsfnet/VP_NSFNet1.py
@@ -1,6 +1,7 @@
 import hydra
 import numpy as np
 from omegaconf import DictConfig
+import paddle

 import ppsci
 from ppsci.utils import logger
@@ -24,16 +25,16 @@ def main(cfg: DictConfig):


 def generate_data(N_TRAIN, lam, seed):
-    x = np.linspace(-0.5, 1.0, 101)
-    y = np.linspace(-0.5, 1.5, 101)
+    x = np.linspace(-0.5, 1.0, 201)
+    y = np.linspace(-0.5, 1.5, 201)

-    yb1 = np.array([-0.5] * 100)
-    yb2 = np.array([1] * 100)
-    xb1 = np.array([-0.5] * 100)
-    xb2 = np.array([1.5] * 100)
+    yb1 = np.array([-0.5] * 200)
+    yb2 = np.array([1] * 200)
+    xb1 = np.array([-0.5] * 200)
+    xb2 = np.array([1.5] * 200)

-    y_train1 = np.concatenate([y[1:101], y[0:100], xb1, xb2], 0).astype("float32")
-    x_train1 = np.concatenate([yb1, yb2, x[0:100], x[1:101]], 0).astype("float32")
+    y_train1 = np.concatenate([y[1:201], y[0:200], xb1, xb2], 0).astype("float32")
+    x_train1 = np.concatenate([yb1, yb2, x[0:200], x[1:201]], 0).astype("float32")

     xb_train = x_train1.reshape(x_train1.shape[0], 1).astype("float32")
     yb_train = y_train1.reshape(y_train1.shape[0], 1).astype("float32")
@@ -67,6 +68,8 @@ def generate_data(N_TRAIN, lam, seed):
 def train(cfg: DictConfig):
     OUTPUT_DIR = cfg.output_dir
     logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")
+    paddle.framework.core.set_prim_eager_enabled(True)
+    paddle.framework.core._set_prim_all_enabled(True)

     # set random seed for reproducibility
     SEED = cfg.seed
@@ -102,11 +105,12 @@ def train(cfg: DictConfig):
         p_star,
     ) = generate_data(N_TRAIN, lam, SEED)

+
     train_dataloader_cfg = {
         "dataset": {
             "name": "NamedArrayDataset",
             "input": {"x": xb_train, "y": yb_train},
-            "label": {"u": ub_train, "v": vb_train},
+            "label": {"u": ub_train, "v": vb_train, "p": vb_train},
         },
         "batch_size": NB_TRAIN,
         "iters_per_epoch": ITERS_PER_EPOCH,
@@ -210,47 +214,12 @@ def train(cfg: DictConfig):
         visualizer=None,
         eval_with_no_grad=False,
         output_dir=OUTPUT_DIR,
+        to_static=True,
     )
     # train model
     solver.train()

-    solver.eval()
-
-    # plot the loss
-    solver.plot_loss_history()
-
-    # set LBFGS optimizer
-    EPOCHS = 5000
-    optimizer = ppsci.optimizer.LBFGS(
-        max_iter=50000, tolerance_change=np.finfo(float).eps, history_size=50
-    )(model)
-
-    logger.init_logger("ppsci", f"{OUTPUT_DIR}/eval.log", "info")
-
-    # initialize solver
-    solver = ppsci.solver.Solver(
-        model=model,
-        constraint=constraint,
-        optimizer=optimizer,
-        epochs=EPOCHS,
-        iters_per_epoch=ITERS_PER_EPOCH,
-        eval_during_train=False,
-        log_freq=2000,
-        eval_freq=2000,
-        seed=SEED,
-        equation=equation,
-        geom=geom,
-        validator=validator,
-        visualizer=None,
-        eval_with_no_grad=False,
-        output_dir=OUTPUT_DIR,
-    )
-    # train model
-    solver.train()
-
-    # evaluate after finished training
-    solver.eval()


 def evaluate(cfg: DictConfig):
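
Note on the hunks above: the two `paddle.framework.core` flags turn on primitive-operator decomposition, which dynamic-to-static conversion needs so the higher-order derivatives in the PDE residual losses can still be taken, and `to_static=True` asks the ppsci solver to run the model as a converted static graph. A minimal sketch of the same flags outside of ppsci follows; the stand-in network, layer sizes, and input are illustrative assumptions, not part of this patch:

import numpy as np
import paddle

# Same flags as in train() above: composite ops are decomposed into
# primitive ops so higher-order autodiff works in static graph mode.
paddle.framework.core.set_prim_eager_enabled(True)
paddle.framework.core._set_prim_all_enabled(True)

# Illustrative stand-in for the PINN model: 2 inputs (x, y) -> 3 outputs (u, v, p).
net = paddle.nn.Sequential(
    paddle.nn.Linear(2, 50),
    paddle.nn.Tanh(),
    paddle.nn.Linear(50, 3),
)

# Dynamic-to-static conversion, roughly what Solver(..., to_static=True) enables.
static_net = paddle.jit.to_static(net, full_graph=True)

xy = paddle.to_tensor(np.random.rand(8, 2).astype("float32"))
print(static_net(xy).shape)  # [8, 3]
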
diff --git a/ppsci/utils/expression.py b/ppsci/utils/expression.py
index b0f5063ec1..d6b95df38c 100644
--- a/ppsci/utils/expression.py
+++ b/ppsci/utils/expression.py
@@ -13,15 +13,17 @@
 # limitations under the License.

 from __future__ import annotations

+import os
 from typing import TYPE_CHECKING
 from typing import Callable
 from typing import Dict
 from typing import Optional
 from typing import Tuple

+import paddle
 from paddle import jit
 from paddle import nn

 if TYPE_CHECKING:
     import paddle
@@ -45,16 +47,21 @@ class ExpressionSolver(nn.Layer):
         >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)
         >>> expr_solver = ExpressionSolver()
         """
-
+
     def __init__(self):
         super().__init__()
+        build_strategy = paddle.static.BuildStrategy()
+        build_strategy.build_cinn_pass = False
+        self.train_forward = paddle.jit.to_static(build_strategy=build_strategy, full_graph=True)(self.train_forward)
+        self.eval_forward = paddle.jit.to_static(build_strategy=build_strategy, full_graph=True)(self.eval_forward)
+
     def forward(self, *args, **kwargs):
         raise NotImplementedError(
             "Use train_forward/eval_forward/visu_forward instead of forward."
         )

-    @jit.to_static
+
     def train_forward(
         self,
         expr_dicts: Tuple[Dict[str, Callable], ...],
@@ -109,7 +116,6 @@ def train_forward(
             constraint_losses.append(constraint_loss)
         return constraint_losses

-    @jit.to_static
     def eval_forward(
         self,
         expr_dict: Dict[str, Callable],
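
The expression.py change replaces the class-level `@jit.to_static` decorators with per-instance wrapping in `__init__`, so a `BuildStrategy` (here with the CINN compiler pass disabled) can be chosen at runtime instead of being fixed at import time. A minimal sketch of that wrapping pattern on a hypothetical layer follows; `Scale` and its `double` method are illustrative, not from this patch:

import paddle


class Scale(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        # Build options chosen at construction time: skip the CINN
        # compiler pass when converting to a static graph.
        build_strategy = paddle.static.BuildStrategy()
        build_strategy.build_cinn_pass = False
        # Re-bind the bound method to its statically converted version,
        # mirroring how ExpressionSolver wraps train_forward/eval_forward.
        self.double = paddle.jit.to_static(
            build_strategy=build_strategy, full_graph=True
        )(self.double)

    def double(self, x):
        return x * 2


layer = Scale()
print(layer.double(paddle.ones([2, 2])))  # executes the static-graph version

Because the wrapped attribute shadows the class method only on that instance, each ExpressionSolver object gets its own converted functions, while the class itself stays importable without triggering conversion.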