From 2fea541be4de6288a4d4d6dc2cae4fe1a682c70d Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 15 May 2024 06:23:32 +0000 Subject: [PATCH 1/3] support default config content in config module and remove deprecated AttrDict series code --- ppsci/arch/__init__.py | 2 +- ppsci/arch/phycrnet.py | 8 +- ppsci/constraint/__init__.py | 2 +- ppsci/data/dataset/__init__.py | 2 +- ppsci/equation/__init__.py | 2 +- ppsci/equation/pde/base.py | 88 +---- ppsci/equation/pde/biharmonic.py | 2 - ppsci/equation/pde/heat_exchanger.py | 2 - ppsci/equation/pde/laplace.py | 2 - ppsci/equation/pde/linear_elasticity.py | 2 - ppsci/equation/pde/navier_stokes.py | 2 - ppsci/equation/pde/nls_m_b.py | 2 - ppsci/equation/pde/normal_dot_vec.py | 2 - ppsci/equation/pde/poisson.py | 2 - ppsci/equation/pde/viv.py | 2 - ppsci/geometry/__init__.py | 2 +- ppsci/loss/__init__.py | 2 +- ppsci/loss/mtl/__init__.py | 2 +- ppsci/metric/__init__.py | 2 +- ppsci/optimizer/__init__.py | 4 +- ppsci/solver/solver.py | 141 +++++--- ppsci/utils/__init__.py | 5 +- ppsci/utils/callbacks.py | 9 +- ppsci/utils/config.py | 411 +++++++++--------------- ppsci/utils/download.py | 2 +- ppsci/utils/symbolic.py | 20 +- ppsci/validate/__init__.py | 2 +- ppsci/visualize/__init__.py | 2 +- 28 files changed, 295 insertions(+), 431 deletions(-) diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py index e59cac085c..807ad07db5 100644 --- a/ppsci/arch/__init__.py +++ b/ppsci/arch/__init__.py @@ -81,7 +81,7 @@ def build_model(cfg): """Build model Args: - cfg (AttrDict): Arch config. + cfg (DictConfig): Arch config. Returns: nn.Layer: Model. diff --git a/ppsci/arch/phycrnet.py b/ppsci/arch/phycrnet.py index 9cd1fca7cf..9c15f1e0a1 100644 --- a/ppsci/arch/phycrnet.py +++ b/ppsci/arch/phycrnet.py @@ -147,7 +147,7 @@ def __init__( ) # ConvLSTM - self.convlstm = paddle.nn.LayerList( + self.ConvLSTM = paddle.nn.LayerList( [ ConvLSTMCell( input_channels=self.input_channels[i], @@ -194,16 +194,16 @@ def forward(self, x): x = encoder(x) # convlstm - for i, lstm in enumerate(self.convlstm, self.num_encoder): + for i, LSTM in enumerate(self.ConvLSTM): if step == 0: - (h, c) = lstm.init_hidden_tensor( + (h, c) = LSTM.init_hidden_tensor( prev_state=self.initial_state[i - self.num_encoder] ) internal_state.append((h, c)) # one-step forward (h, c) = internal_state[i - self.num_encoder] - x, new_c = lstm(x, h, c) + x, new_c = LSTM(x, h, c) internal_state[i - self.num_encoder] = (x, new_c) # output diff --git a/ppsci/constraint/__init__.py b/ppsci/constraint/__init__.py index 6cbe1a42b0..9179439436 100644 --- a/ppsci/constraint/__init__.py +++ b/ppsci/constraint/__init__.py @@ -42,7 +42,7 @@ def build_constraint(cfg, equation_dict, geom_dict): """Build constraint(s). Args: - cfg (List[AttrDict]): Constraint config list. + cfg (List[DictConfig]): Constraint config list. equation_dict (Dct[str, Equation]): Equation(s) in dict. geom_dict (Dct[str, Geometry]): Geometry(ies) in dict. diff --git a/ppsci/data/dataset/__init__.py b/ppsci/data/dataset/__init__.py index c0eebe860e..960e8a66b9 100644 --- a/ppsci/data/dataset/__init__.py +++ b/ppsci/data/dataset/__init__.py @@ -78,7 +78,7 @@ def build_dataset(cfg) -> "io.Dataset": """Build dataset Args: - cfg (List[AttrDict]): dataset config list. + cfg (List[DictConfig]): dataset config list. Returns: Dict[str, io.Dataset]: dataset. 
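Patch 1 replaces the removed AttrDict with omegaconf's DictConfig throughout these factory docstrings. For readers unfamiliar with the type, a minimal sketch of how such a config object behaves, assuming omegaconf is installed; the keys below are illustrative, not the exact PaddleScience schema:

    from omegaconf import DictConfig, OmegaConf

    # Build a config resembling what build_model(cfg) receives; the 'arch' keys are hypothetical.
    cfg: DictConfig = OmegaConf.create(
        {"arch": {"input_keys": ["x", "y"], "output_keys": ["u"], "num_layers": 5}}
    )
    print(cfg.arch.num_layers)     # attribute-style access, like the old AttrDict
    print(OmegaConf.to_yaml(cfg))  # round-trips to YAML
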
diff --git a/ppsci/equation/__init__.py b/ppsci/equation/__init__.py index 77a9b20860..2b97d378b7 100644 --- a/ppsci/equation/__init__.py +++ b/ppsci/equation/__init__.py @@ -54,7 +54,7 @@ def build_equation(cfg): """Build equation(s) Args: - cfg (List[AttrDict]): Equation(s) config list. + cfg (List[DictConfig]): Equation(s) config list. Returns: Dict[str, Equation]: Equation(s) in dict. diff --git a/ppsci/equation/pde/base.py b/ppsci/equation/pde/base.py index b5affbcf75..9ef55712a3 100644 --- a/ppsci/equation/pde/base.py +++ b/ppsci/equation/pde/base.py @@ -22,7 +22,7 @@ from typing import Union import paddle -import sympy as sp +import sympy from paddle import nn DETACH_FUNC_NAME = "detach" @@ -33,7 +33,7 @@ class PDE: def __init__(self): super().__init__() - self.equations: Dict[str, Union[Callable, sp.Basic]] = {} + self.equations = {} # for PDE which has learnable parameter(s) self.learnable_parameters = nn.ParameterList() @@ -42,7 +42,7 @@ def __init__(self): @staticmethod def create_symbols( symbol_str: str, - ) -> Union[sp.Symbol, Tuple[sp.Symbol, ...]]: + ) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]: """create symbolic variables. Args: @@ -61,9 +61,11 @@ def create_symbols( >>> print(symbols_xyz) (x, y, z) """ - return sp.symbols(symbol_str) + return sympy.symbols(symbol_str) - def create_function(self, name: str, invars: Tuple[sp.Symbol, ...]) -> sp.Function: + def create_function( + self, name: str, invars: Tuple[sympy.Symbol, ...] + ) -> sympy.Function: """Create named function depending on given invars. Args: @@ -84,73 +86,14 @@ def create_function(self, name: str, invars: Tuple[sp.Symbol, ...]) -> sp.Functi >>> print(f) f(x, y, z) """ - expr = sp.Function(name)(*invars) + expr = sympy.Function(name)(*invars) + # wrap `expression(...)` to `detach(expression(...))` + # if name of expression is in given detach_keys + if self.detach_keys and name in self.detach_keys: + expr = sympy.Function(DETACH_FUNC_NAME)(expr) return expr - def _apply_detach(self): - """ - Wrap detached sub_expr into detach(sub_expr) to prevent gradient - back-propagation, only for those items speicified in self.detach_keys. - - NOTE: This function is expected to be called after self.equations is ready in PDE.__init__. 
- - Examples: - >>> import ppsci - >>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False) - >>> print(ns) - NavierStokes - continuity: Derivative(u(x, y), x) + Derivative(v(x, y), y) - momentum_x: u(x, y)*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2)) - momentum_y: u(x, y)*Derivative(v(x, y), x) + v(x, y)*Derivative(v(x, y), y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2)) - >>> detach_keys = ("u", "v__y") - >>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False, detach_keys=detach_keys) - >>> print(ns) - NavierStokes - continuity: detach(Derivative(v(x, y), y)) + Derivative(u(x, y), x) - momentum_x: detach(u(x, y))*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2)) - momentum_y: detach(u(x, y))*Derivative(v(x, y), x) + detach(Derivative(v(x, y), y))*v(x, y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2)) - """ - if self.detach_keys is None: - return - - from copy import deepcopy - - from sympy.core.traversal import postorder_traversal - - from ppsci.utils.symbolic import _cvt_to_key - - for name, expr in self.equations.items(): - if not isinstance(expr, sp.Basic): - continue - # only process sympy expression - expr_ = deepcopy(expr) - for item in postorder_traversal(expr): - if _cvt_to_key(item) in self.detach_keys: - # inplace all related sub_expr into detach(sub_expr) - expr_ = expr_.replace(item, sp.Function(DETACH_FUNC_NAME)(item)) - - # remove all detach wrapper for more-than-once wrapped items to prevent duplicated wrapping - expr_ = expr_.replace( - sp.Function(DETACH_FUNC_NAME)( - sp.Function(DETACH_FUNC_NAME)(item) - ), - sp.Function(DETACH_FUNC_NAME)(item), - ) - - # remove unccessary detach wrapping for the first arg of Derivative - for item_ in list(postorder_traversal(expr_)): - if isinstance(item_, sp.Derivative): - if item_.args[0].name == DETACH_FUNC_NAME: - expr_ = expr_.replace( - item_, - sp.Derivative( - item_.args[0].args[0], *item_.args[1:] - ), - ) - - self.equations[name] = expr_ - def add_equation(self, name: str, equation: Callable): """Add an equation. 
@@ -167,8 +110,7 @@ def add_equation(self, name: str, equation: Callable): >>> equation = sympy.diff(u, x) + sympy.diff(u, y) >>> pde.add_equation('linear_pde', equation) >>> print(pde) - PDE - linear_pde: 2*x + 2*y + PDE, linear_pde: 2*x + 2*y """ self.equations.update({name: equation}) @@ -239,7 +181,7 @@ def set_state_dict( return self.learnable_parameters.set_state_dict(state_dict) def __str__(self): - return "\n".join( + return ", ".join( [self.__class__.__name__] - + [f" {name}: {eq}" for name, eq in self.equations.items()] + + [f"{name}: {eq}" for name, eq in self.equations.items()] ) diff --git a/ppsci/equation/pde/biharmonic.py b/ppsci/equation/pde/biharmonic.py index 933888ac60..1471c34a6c 100644 --- a/ppsci/equation/pde/biharmonic.py +++ b/ppsci/equation/pde/biharmonic.py @@ -70,5 +70,3 @@ def __init__( biharmonic += u.diff(invar_i, 2).diff(invar_j, 2) self.add_equation("biharmonic", biharmonic) - - self._apply_detach() diff --git a/ppsci/equation/pde/heat_exchanger.py b/ppsci/equation/pde/heat_exchanger.py index c2e0107ff3..d9fd93c224 100644 --- a/ppsci/equation/pde/heat_exchanger.py +++ b/ppsci/equation/pde/heat_exchanger.py @@ -90,5 +90,3 @@ def __init__( self.add_equation("heat_boundary", heat_boundary) self.add_equation("cold_boundary", cold_boundary) self.add_equation("wall", wall) - - self._apply_detach() diff --git a/ppsci/equation/pde/laplace.py b/ppsci/equation/pde/laplace.py index b99d7c8d9a..12b2a03ddd 100644 --- a/ppsci/equation/pde/laplace.py +++ b/ppsci/equation/pde/laplace.py @@ -51,5 +51,3 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None): laplace += u.diff(invar, 2) self.add_equation("laplace", laplace) - - self._apply_detach() diff --git a/ppsci/equation/pde/linear_elasticity.py b/ppsci/equation/pde/linear_elasticity.py index 44833f56bf..9120c6d21c 100644 --- a/ppsci/equation/pde/linear_elasticity.py +++ b/ppsci/equation/pde/linear_elasticity.py @@ -179,5 +179,3 @@ def __init__( self.add_equation("traction_y", traction_y) if self.dim == 3: self.add_equation("traction_z", traction_z) - - self._apply_detach() diff --git a/ppsci/equation/pde/navier_stokes.py b/ppsci/equation/pde/navier_stokes.py index c0d3d193a2..41cb819bf9 100644 --- a/ppsci/equation/pde/navier_stokes.py +++ b/ppsci/equation/pde/navier_stokes.py @@ -147,5 +147,3 @@ def __init__( self.add_equation("momentum_y", momentum_y) if self.dim == 3: self.add_equation("momentum_z", momentum_z) - - self._apply_detach() diff --git a/ppsci/equation/pde/nls_m_b.py b/ppsci/equation/pde/nls_m_b.py index 3db2984268..97bf60cabb 100644 --- a/ppsci/equation/pde/nls_m_b.py +++ b/ppsci/equation/pde/nls_m_b.py @@ -97,5 +97,3 @@ def __init__( self.add_equation("Maxwell_1", Maxwell_1) self.add_equation("Maxwell_2", Maxwell_2) self.add_equation("Bloch", Bloch) - - self._apply_detach() diff --git a/ppsci/equation/pde/normal_dot_vec.py b/ppsci/equation/pde/normal_dot_vec.py index a6f3942eeb..de97a140fb 100644 --- a/ppsci/equation/pde/normal_dot_vec.py +++ b/ppsci/equation/pde/normal_dot_vec.py @@ -55,5 +55,3 @@ def __init__( normal_dot_vec += normal * vec self.add_equation("normal_dot_vec", normal_dot_vec) - - self._apply_detach() diff --git a/ppsci/equation/pde/poisson.py b/ppsci/equation/pde/poisson.py index 4f9551a23a..e83fecde05 100644 --- a/ppsci/equation/pde/poisson.py +++ b/ppsci/equation/pde/poisson.py @@ -49,5 +49,3 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None): poisson += p.diff(invar, 2) self.add_equation("poisson", poisson) - - self._apply_detach() 
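The base.py hunks above replace the post-hoc _apply_detach() pass with detach-wrapping at function-creation time: create_function now wraps f(x, y) in detach(...) whenever its name appears in detach_keys. A minimal sympy sketch of that creation-time behaviour (the names are illustrative):

    import sympy as sp

    DETACH_FUNC_NAME = "detach"
    detach_keys = ("u",)

    x, y = sp.symbols("x y")
    expr = sp.Function("u")(x, y)
    # Mimic the new create_function: wrap at creation time when the name is listed.
    if "u" in detach_keys:
        expr = sp.Function(DETACH_FUNC_NAME)(expr)
    print(expr)  # detach(u(x, y))

Note that this variant can only wrap whole functions; derivative keys such as "v__y", which the removed traversal in _apply_detach handled, no longer match, which is presumably why the third patch in this series restores _apply_detach.
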
diff --git a/ppsci/equation/pde/viv.py b/ppsci/equation/pde/viv.py
index c3d85895f1..68fd61a446 100644
--- a/ppsci/equation/pde/viv.py
+++ b/ppsci/equation/pde/viv.py
@@ -60,5 +60,3 @@ def __init__(self, rho: float, k1: float, k2: float):
         k2 = self.create_symbols(self.k2.name)
         f = self.rho * eta.diff(t_f, 2) + sp.exp(k1) * eta.diff(t_f) + sp.exp(k2) * eta
         self.add_equation("f", f)
-
-        self._apply_detach()
diff --git a/ppsci/geometry/__init__.py b/ppsci/geometry/__init__.py
index 4f1ff0b122..768ed0581d 100644
--- a/ppsci/geometry/__init__.py
+++ b/ppsci/geometry/__init__.py
@@ -54,7 +54,7 @@ def build_geometry(cfg):
     """Build geometry(ies)
 
     Args:
-        cfg (List[AttrDict]): Geometry config list.
+        cfg (List[DictConfig]): Geometry config list.
 
     Returns:
         Dict[str, Geometry]: Geometry(ies) in dict.
diff --git a/ppsci/loss/__init__.py b/ppsci/loss/__init__.py
index 0035a4193f..8bb9496f68 100644
--- a/ppsci/loss/__init__.py
+++ b/ppsci/loss/__init__.py
@@ -53,7 +53,7 @@ def build_loss(cfg):
     """Build loss.
 
     Args:
-        cfg (AttrDict): Loss config.
+        cfg (DictConfig): Loss config.
 
     Returns:
         Loss: Callable loss object.
     """
diff --git a/ppsci/loss/mtl/__init__.py b/ppsci/loss/mtl/__init__.py
index 35f3b73d90..358efb3609 100644
--- a/ppsci/loss/mtl/__init__.py
+++ b/ppsci/loss/mtl/__init__.py
@@ -35,7 +35,7 @@ def build_mtl_aggregator(cfg):
     """Build loss aggregator with multi-task learning method.
 
     Args:
-        cfg (AttrDict): Aggregator config.
+        cfg (DictConfig): Aggregator config.
 
     Returns:
         Loss: Callable loss aggregator object.
     """
diff --git a/ppsci/metric/__init__.py b/ppsci/metric/__init__.py
index 5390db4c4e..6059b22116 100644
--- a/ppsci/metric/__init__.py
+++ b/ppsci/metric/__init__.py
@@ -43,7 +43,7 @@ def build_metric(cfg):
     """Build metric.
 
     Args:
-        cfg (List[AttrDict]): List of metric config.
+        cfg (List[DictConfig]): List of metric config.
 
     Returns:
         Dict[str, Metric]: Dict of callable metric object.
diff --git a/ppsci/optimizer/__init__.py b/ppsci/optimizer/__init__.py
index c973b489fb..7dcf33b40b 100644
--- a/ppsci/optimizer/__init__.py
+++ b/ppsci/optimizer/__init__.py
@@ -39,7 +39,7 @@ def build_lr_scheduler(cfg, epochs, iters_per_epoch):
     """Build learning rate scheduler.
 
     Args:
-        cfg (AttrDict): Learning rate scheduler config.
+        cfg (DictConfig): Learning rate scheduler config.
         epochs (int): Total epochs.
         iters_per_epoch (int): Number of iterations of one epoch.
 
@@ -57,7 +57,7 @@ def build_optimizer(cfg, model_list, epochs, iters_per_epoch):
     """Build optimizer and learning rate scheduler
 
     Args:
-        cfg (AttrDict): Learning rate scheduler config.
+        cfg (DictConfig): Optimizer config.
         model_list (Tuple[nn.Layer, ...]): Tuple of model(s).
         epochs (int): Total epochs.
         iters_per_epoch (int): Number of iterations of one epoch.
diff --git a/ppsci/solver/solver.py b/ppsci/solver/solver.py
index f7a00aa8fc..cde418ea57 100644
--- a/ppsci/solver/solver.py
+++ b/ppsci/solver/solver.py
@@ -158,12 +158,18 @@ def __init__(
         cfg: Optional[DictConfig] = None,
     ):
         self.cfg = cfg
+        if isinstance(cfg, DictConfig):
+            # (Recommended) Params can be passed within cfg
+            # rather than passed to 'Solver.__init__' one-by-one.
+ self._parse_params_from_cfg(cfg) + # set model self.model = model # set constraint self.constraint = constraint # set output directory - self.output_dir = output_dir + if not cfg: + self.output_dir = output_dir # set optimizer self.optimizer = optimizer @@ -192,19 +198,20 @@ def __init__( ) # set training hyper-parameter - self.epochs = epochs - self.iters_per_epoch = iters_per_epoch - # set update_freq for gradient accumulation - self.update_freq = update_freq - # set checkpoint saving frequency - self.save_freq = save_freq - # set logging frequency - self.log_freq = log_freq - - # set evaluation hyper-parameter - self.eval_during_train = eval_during_train - self.start_eval_epoch = start_eval_epoch - self.eval_freq = eval_freq + if not cfg: + self.epochs = epochs + self.iters_per_epoch = iters_per_epoch + # set update_freq for gradient accumulation + self.update_freq = update_freq + # set checkpoint saving frequency + self.save_freq = save_freq + # set logging frequency + self.log_freq = log_freq + + # set evaluation hyper-parameter + self.eval_during_train = eval_during_train + self.start_eval_epoch = start_eval_epoch + self.eval_freq = eval_freq # initialize training log(training loss, time cost, etc.) recorder during one epoch self.train_output_info: Dict[str, misc.AverageMeter] = {} @@ -221,21 +228,17 @@ def __init__( "reader_cost": misc.AverageMeter("reader_cost", ".5f", postfix="s"), } - # fix seed for reproducibility - self.seed = seed - # set running device - if device != "cpu" and paddle.device.get_device() == "cpu": + if not cfg: + self.device = device + if self.device != "cpu" and paddle.device.get_device() == "cpu": logger.warning(f"Set device({device}) to 'cpu' for only cpu available.") - device = "cpu" - self.device = paddle.set_device(device) + self.device = "cpu" + self.device = paddle.set_device(self.device) # set equations for physics-driven or data-physics hybrid driven task, such as PINN self.equation = equation - # set geometry for generating data - self.geom = {} if geom is None else geom - # set validator self.validator = validator @@ -243,24 +246,27 @@ def __init__( self.visualizer = visualizer # set automatic mixed precision(AMP) configuration - self.use_amp = use_amp - self.amp_level = amp_level + if not cfg: + self.use_amp = use_amp + self.amp_level = amp_level self.scaler = amp.GradScaler(True) if self.use_amp else None # whether calculate metrics by each batch during evaluation, mainly for memory efficiency - self.compute_metric_by_batch = compute_metric_by_batch + if not cfg: + self.compute_metric_by_batch = compute_metric_by_batch if validator is not None: for metric in itertools.chain( *[_v.metric.values() for _v in self.validator.values()] ): - if metric.keep_batch ^ compute_metric_by_batch: + if metric.keep_batch ^ self.compute_metric_by_batch: raise ValueError( f"{misc.typename(metric)}.keep_batch should be " - f"{compute_metric_by_batch} when compute_metric_by_batch=" - f"{compute_metric_by_batch}." + f"{self.compute_metric_by_batch} when compute_metric_by_batch=" + f"{self.compute_metric_by_batch}." 
) # whether set `stop_gradient=True` for every Tensor if no differentiation involved during evaluation - self.eval_with_no_grad = eval_with_no_grad + if not cfg: + self.eval_with_no_grad = eval_with_no_grad self.rank = dist.get_rank() self.world_size = dist.get_world_size() @@ -278,19 +284,20 @@ def __init__( # set moving average model(optional) self.ema_model = None if self.cfg and any(key in self.cfg.TRAIN for key in ["ema", "swa"]): - if "ema" in self.cfg.TRAIN: - self.avg_freq = self.cfg.TRAIN.ema.avg_freq + if "ema" in self.cfg.TRAIN and cfg.TRAIN.ema.get("use_ema", False): self.ema_model = ema.ExponentialMovingAverage( self.model, self.cfg.TRAIN.ema.decay ) - elif "swa" in self.cfg.TRAIN: - self.avg_freq = self.cfg.TRAIN.swa.avg_freq + elif "swa" in self.cfg.TRAIN and cfg.TRAIN.swa.get("use_swa", False): self.ema_model = ema.StochasticWeightAverage(self.model) # load pretrained model, usually used for transfer learning - self.pretrained_model_path = pretrained_model_path - if pretrained_model_path is not None: - save_load.load_pretrain(self.model, pretrained_model_path, self.equation) + if not cfg: + self.pretrained_model_path = pretrained_model_path + if self.pretrained_model_path is not None: + save_load.load_pretrain( + self.model, self.pretrained_model_path, self.equation + ) # initialize an dict for tracking best metric during training self.best_metric = { @@ -298,14 +305,16 @@ def __init__( "epoch": 0, } # load model checkpoint, usually used for resume training - if checkpoint_path is not None: - if pretrained_model_path is not None: + if not cfg: + self.checkpoint_path = checkpoint_path + if self.checkpoint_path is not None: + if self.pretrained_model_path is not None: logger.warning( "Detected 'pretrained_model_path' is given, weights in which might be" "overridden by weights loaded from given 'checkpoint_path'." ) loaded_metric = save_load.load_checkpoint( - checkpoint_path, + self.checkpoint_path, self.model, self.optimizer, self.scaler, @@ -366,7 +375,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel: # set VisualDL tool self.vdl_writer = None - if use_vdl: + if not cfg: + self.use_vdl = use_vdl + if self.use_vdl: with misc.RankZeroOnly(self.rank) as is_master: if is_master: self.vdl_writer = vdl.LogWriter(osp.join(output_dir, "vdl")) @@ -377,7 +388,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel: # set WandB tool self.wandb_writer = None - if use_wandb: + if not cfg: + self.use_wandb = use_wandb + if self.use_wandb: try: import wandb except ModuleNotFoundError: @@ -390,7 +403,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel: # set TensorBoardX tool self.tbd_writer = None - if use_tbd: + if not cfg: + self.use_tbd = use_tbd + if self.use_tbd: try: import tensorboardX except ModuleNotFoundError: @@ -984,3 +999,43 @@ def plot_loss_history( smooth_step=smooth_step, use_semilogy=use_semilogy, ) + + def _parse_params_from_cfg(self, cfg: DictConfig): + """ + Parse hyper-parameters from DictConfig. 
+        """
+        self.output_dir = cfg.output_dir
+        self.log_freq = cfg.log_freq
+        self.use_tbd = cfg.use_tbd
+        self.use_vdl = cfg.use_vdl
+        self.wandb_config = cfg.wandb_config
+        self.use_wandb = cfg.use_wandb
+        self.device = cfg.device
+        self.to_static = cfg.to_static
+
+        self.use_amp = cfg.use_amp
+        self.amp_level = cfg.amp_level
+
+        self.epochs = cfg.TRAIN.epochs
+        self.iters_per_epoch = cfg.TRAIN.iters_per_epoch
+        self.update_freq = cfg.TRAIN.update_freq
+        self.save_freq = cfg.TRAIN.save_freq
+        self.eval_during_train = cfg.TRAIN.eval_during_train
+        self.start_eval_epoch = cfg.TRAIN.start_eval_epoch
+        self.eval_freq = cfg.TRAIN.eval_freq
+        self.checkpoint_path = cfg.TRAIN.checkpoint_path
+
+        if "ema" in cfg.TRAIN and cfg.TRAIN.ema.get("use_ema", False):
+            self.avg_freq = cfg.TRAIN.ema.avg_freq
+        elif "swa" in cfg.TRAIN and cfg.TRAIN.swa.get("use_swa", False):
+            self.avg_freq = cfg.TRAIN.swa.avg_freq
+
+        self.compute_metric_by_batch = cfg.EVAL.compute_metric_by_batch
+        self.eval_with_no_grad = cfg.EVAL.eval_with_no_grad
+
+        if cfg.mode == "train":
+            self.pretrained_model_path = cfg.TRAIN.pretrained_model_path
+        elif cfg.mode == "eval":
+            self.pretrained_model_path = cfg.EVAL.pretrained_model_path
+        elif cfg.mode in ["export", "infer"]:
+            self.pretrained_model_path = cfg.INFER.pretrained_model_path
diff --git a/ppsci/utils/__init__.py b/ppsci/utils/__init__.py
index 5b076fb3bb..f397f090ac 100644
--- a/ppsci/utils/__init__.py
+++ b/ppsci/utils/__init__.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# NOTE: Put config module import at the top level for registering default config(s) in
+# ConfigStore at the beginning of ppsci
+from ppsci.utils import config  # isort:skip  # noqa: F401
 from ppsci.utils import ema
 from ppsci.utils import initializer
 from ppsci.utils import logger
@@ -22,7 +25,6 @@
 from ppsci.utils.checker import dynamic_import_to_globals
 from ppsci.utils.checker import run_check
 from ppsci.utils.checker import run_check_mesh
-from ppsci.utils.config import AttrDict
 from ppsci.utils.expression import ExpressionSolver
 from ppsci.utils.misc import AverageMeter
 from ppsci.utils.misc import set_random_seed
@@ -39,7 +41,6 @@
 from ppsci.utils.writer import save_tecplot_file
 
 __all__ = [
-    "AttrDict",
     "AverageMeter",
     "ExpressionSolver",
     "initializer",
diff --git a/ppsci/utils/callbacks.py b/ppsci/utils/callbacks.py
index e55a29130f..bcfbbd46bd 100644
--- a/ppsci/utils/callbacks.py
+++ b/ppsci/utils/callbacks.py
@@ -31,9 +31,10 @@
 class InitCallback(Callback):
     """Callback class for:
-    1. Parse config dict from given yaml file and check its validity, complete missing items by its' default values.
+    1. Parse config dict from given yaml file and check its validity.
     2. Fixing random seed to 'config.seed'.
     3. Initialize logger while creating output directory(if not exist).
+    4. Enable prim mode if specified.
 
     NOTE: This callback is mainly for reducing unnecessary duplicate code in each
     examples code when runing with hydra.
@@ -60,8 +61,6 @@ class InitCallback(Callback):
     """
 
     def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
-        # check given cfg using pre-defined pydantic schema in 'SolverConfig', error(s) will be raised
-        # if any checking failed at this step
         if importlib.util.find_spec("pydantic") is not None:
             from pydantic import ValidationError
         else:
@@ -76,8 +75,6 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
         # error(s) will be printed and exit program if any checking failed at this step
         try:
             _model_pydantic = config_module.SolverConfig(**dict(config))
-            # complete missing items with default values pre-defined in pydantic schema in
-            # 'SolverConfig'
             full_cfg = DictConfig(_model_pydantic.model_dump())
         except ValidationError as e:
             print(e)
@@ -100,7 +97,7 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None:
 
         # enable prim if specified
         if "prim" in full_cfg and bool(full_cfg.prim):
-            # Mostly for dy2st running, will be removed in the future
+            # Mostly for compiler running with dy2st.
             from paddle.framework import core
 
             core.set_prim_eager_enabled(True)
diff --git a/ppsci/utils/config.py b/ppsci/utils/config.py
index af28f2e207..0352d2f7ce 100644
--- a/ppsci/utils/config.py
+++ b/ppsci/utils/config.py
@@ -14,30 +14,77 @@
 from __future__ import annotations
 
-import argparse
-import copy
 import importlib.util
-import os
 from typing import Mapping
 from typing import Optional
 from typing import Tuple
 
-import yaml
-from paddle import static
 from typing_extensions import Literal
 
-from ppsci.utils import logger
-from ppsci.utils import misc
-
-__all__ = ["get_config", "replace_shape_with_inputspec_", "AttrDict"]
+__all__ = []
 
 if importlib.util.find_spec("pydantic") is not None:
+    from hydra.core.config_store import ConfigStore
+    from omegaconf import OmegaConf
     from pydantic import BaseModel
     from pydantic import field_validator
+    from pydantic import model_validator
     from pydantic_core.core_schema import ValidationInfo
 
     __all__.append("SolverConfig")
 
+    class EMAConfig(BaseModel):
+        use_ema: bool = False
+        decay: float = 0.9
+        avg_freq: int = 1
+
+        @field_validator("decay")
+        def decay_check(cls, v):
+            if v <= 0 or v >= 1:
+                raise ValueError(
+                    f"'decay' should be in (0, 1) when is type of float, but got {v}"
+                )
+            return v
+
+        @field_validator("avg_freq")
+        def avg_freq_check(cls, v):
+            if v <= 0:
+                raise ValueError(
+                    "'avg_freq' should be a positive integer when is type of int, "
+                    f"but got {v}"
+                )
+            return v
+
+    class SWAConfig(BaseModel):
+        use_swa: bool = False
+        avg_freq: int = 1
+        avg_range: Optional[Tuple[int, int]] = None
+
+        @field_validator("avg_range")
+        def avg_range_check(cls, v, info: ValidationInfo):
+            if isinstance(v, tuple) and v[0] > v[1]:
+                raise ValueError(f"'avg_range' should be a valid range, but got {v}.")
+            if isinstance(v, tuple) and v[0] < 0:
+                raise ValueError(
+                    "The start epoch of 'avg_range' should be a non-negative integer,"
+                    f" but got {v[0]}."
+                )
+            if isinstance(v, tuple) and v[1] > info.data["epochs"]:
+                raise ValueError(
+                    "The end epoch of 'avg_range' should not be larger than "
+                    f"'epochs'({info.data['epochs']}), but got {v[1]}."
+                )
+            return v
+
+        @field_validator("avg_freq")
+        def avg_freq_check(cls, v):
+            if v <= 0:
+                raise ValueError(
+                    "'avg_freq' should be a positive integer when is type of int, "
+                    f"but got {v}"
+                )
+            return v
+
     class TrainConfig(BaseModel):
         """
         Schema of training config for pydantic validation.
@@ -55,58 +102,6 @@ class TrainConfig(BaseModel): ema: Optional[EMAConfig] = None swa: Optional[SWAConfig] = None - class EMAConfig(BaseModel): - decay: float = 0.9 - avg_freq: int = 1 - - @field_validator("decay") - def decay_check(cls, v): - if v <= 0 or v >= 1: - raise ValueError( - f"'decay' should be in (0, 1) when is type of float, but got {v}" - ) - return v - - @field_validator("avg_freq") - def avg_freq_check(cls, v): - if v <= 0: - raise ValueError( - "'avg_freq' should be a positive integer when is type of int, " - f"but got {v}" - ) - return v - - class SWAConfig(BaseModel): - avg_freq: int = 1 - avg_range: Optional[Tuple[int, int]] = None - - @field_validator("avg_range") - def avg_range_check(cls, v, info: ValidationInfo): - if v[0] > v[1]: - raise ValueError( - f"'avg_range' should be a valid range, but got {v}." - ) - if v[0] < 0: - raise ValueError( - "The start epoch of 'avg_range' should be a non-negtive integer" - f" , but got {v[0]}." - ) - if v[1] > info.data["epochs"]: - raise ValueError( - "The end epoch of 'avg_range' should not be lager than " - f"'epochs'({info.data['epochs']}), but got {v[1]}." - ) - return v - - @field_validator("avg_freq") - def avg_freq_check(cls, v): - if v <= 0: - raise ValueError( - "'avg_freq' should be a positive integer when is type of int, " - f"but got {v}" - ) - return v - # Fine-grained validator(s) below @field_validator("epochs") def epochs_check(cls, v): @@ -164,21 +159,14 @@ def eval_freq_check(cls, v, info: ValidationInfo): ) return v - @field_validator("ema") - def ema_check(cls, v, info: ValidationInfo): - if "swa" in info.data and info.data["swa"] is not None: - raise ValueError( - "The config of 'swa' should not be used when 'ema' is specifed." - ) - return v - - @field_validator("swa") - def swa_check(cls, v, info: ValidationInfo): - if "ema" in info.data and info.data["ema"] is not None: + @model_validator(mode="after") + def ema_swa_checker(self): + if (self.ema and self.swa) and (self.ema.use_ema and self.swa.use_swa): raise ValueError( - "The config of 'ema' should not be used when 'swa' is specifed." + "Cannot enable both EMA and SWA at the same time, " + "please disable at least one of them." 
) - return v + return self class EvalConfig(BaseModel): """ @@ -195,7 +183,7 @@ class InferConfig(BaseModel): """ pretrained_model_path: Optional[str] = None - export_path: str + export_path: str = "./inference" pdmodel_path: Optional[str] = None pdiparams_path: Optional[str] = None onnx_path: Optional[str] = None @@ -284,8 +272,9 @@ class SolverConfig(BaseModel): log_freq: int = 20 seed: int = 42 use_vdl: bool = False - use_wandb: bool = False + use_tbd: bool = False wandb_config: Optional[Mapping] = None + use_wandb: bool = False device: Literal["cpu", "gpu", "xpu"] = "gpu" use_amp: bool = False amp_level: Literal["O0", "O1", "O2", "OD"] = "O1" @@ -320,195 +309,99 @@ def seed_check(cls, v): @field_validator("use_wandb") def use_wandb_check(cls, v, info: ValidationInfo): - if not isinstance(info.data["wandb_config"], dict): + if v and not isinstance(info.data["wandb_config"], dict): raise ValueError( "'wandb_config' should be a dict when 'use_wandb' is True, " - f"but got {misc.typename(info.data['wandb_config'])}" + f"but got {info.data['wandb_config'].__class__.__name__}" ) return v - -class AttrDict(dict): - def __getattr__(self, key): - return self[key] - - def __setattr__(self, key, value): - if key in self.__dict__: - self.__dict__[key] = value - else: - self[key] = value - - def __deepcopy__(self, content): - return AttrDict(copy.deepcopy(dict(self))) - - -def create_attr_dict(yaml_config): - from ast import literal_eval - - for key, value in yaml_config.items(): - if isinstance(value, dict): - yaml_config[key] = value = AttrDict(value) - if isinstance(value, str): - try: - value = literal_eval(value) - except BaseException: - pass - if isinstance(value, AttrDict): - create_attr_dict(yaml_config[key]) - else: - yaml_config[key] = value - - -def parse_config(cfg_file): - """Load a config file into AttrDict""" - with open(cfg_file, "r") as fopen: - yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader)) - create_attr_dict(yaml_config) - return yaml_config - - -def print_dict(d, delimiter=0): - """ - Recursively visualize a dict and - indenting according by the relationship of keys. - """ - placeholder = "-" * 60 - for k, v in d.items(): - if isinstance(v, dict): - logger.info(f"{delimiter * ' '}{k} : ") - print_dict(v, delimiter + 4) - elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict): - logger.info(f"{delimiter * ' '}{k} : ") - for value in v: - print_dict(value, delimiter + 2) - else: - logger.info(f"{delimiter * ' '}{k} : {v}") - - if k[0].isupper() and delimiter == 0: - logger.info(placeholder) - - -def print_config(config): - """ - Visualize configs - Arguments: - config: configs - """ - logger.advertise() - print_dict(config) - - -def override(dl, ks, v): + # Register 'XXXConfig' as default node, so as to be used as default config in *.yaml """ - Recursively replace dict of list - Args: - dl(dict or list): dict or list to be replaced - ks(list): list of keys - v(str): value to be replaced + #### xxx.yaml #### + defaults: + - ppsci_default <-- 'ppsci_default' used here + - TRAIN: train_default <-- 'train_default' used here + - TRAIN/ema: ema_default <-- 'ema_default' used here + - TRAIN/swa: swa_default <-- 'swa_default' used here + - EVAL: eval_default <-- 'eval_default' used here + - INFER: infer_default <-- 'infer_default' used here + - _self_ + mode: train + seed: 42 + ... + ... 
+ ################## """ - def str2num(v): - try: - return eval(v) - except Exception: - return v - - if not isinstance(dl, (list, dict)): - raise ValueError(f"{dl} should be a list or a dict") - if len(ks) <= 0: - raise ValueError("length of keys should be larger than 0") - - if isinstance(dl, list): - k = str2num(ks[0]) - if len(ks) == 1: - if k >= len(dl): - raise ValueError(f"index({k}) out of range({dl})") - dl[k] = str2num(v) - else: - override(dl[k], ks[1:], v) - else: - if len(ks) == 1: - # assert ks[0] in dl, (f"{ks[0]} is not exist in {dl}") - if ks[0] not in dl: - print(f"A new field ({ks[0]}) detected!") - dl[ks[0]] = str2num(v) - else: - if ks[0] not in dl.keys(): - dl[ks[0]] = {} - print(f"A new Series field ({ks[0]}) detected!") - override(dl[ks[0]], ks[1:], v) - - -def override_config(config, options=None): - """ - Recursively override the config - Args: - config(dict): dict to be replaced - options(list): list of pairs(key0.key1.idx.key2=value) - such as: [ - "topk=2", - "VALID.transforms.1.ResizeImage.resize_short=300" - ] - Returns: - config(dict): replaced config - """ - if options is not None: - for opt in options: - assert isinstance(opt, str), f"option({opt}) should be a str" - assert ( - "=" in opt - ), f"option({opt}) should contain a = to distinguish between key and value" - pair = opt.split("=") - assert len(pair) == 2, "there can be only a = in the option" - key, value = pair - keys = key.split(".") - override(config, keys, value) - return config - - -def get_config(fname, overrides=None, show=False): - """ - Read config from file - """ - if not os.path.exists(fname): - raise FileNotFoundError(f"config file({fname}) is not exist") - config = parse_config(fname) - override_config(config, overrides) - if show: - print_config(config) - return config - - -def parse_args(): - parser = argparse.ArgumentParser("paddlescience running script") - parser.add_argument("-e", "--epochs", type=int, help="training epochs") - parser.add_argument("-o", "--output_dir", type=str, help="output directory") - parser.add_argument( - "--to_static", - action="store_true", - help="whether enable to_static for forward computation", + cs = ConfigStore.instance() + + global_default_cfg = SolverConfig().model_dump() + omegaconf_dict_config = OmegaConf.create(global_default_cfg) + cs.store(name="ppsci_default", node=omegaconf_dict_config) + + train_default_cfg = TrainConfig().model_dump() + train_omegaconf_dict_config = OmegaConf.create(train_default_cfg) + cs.store(group="TRAIN", name="train_default", node=train_omegaconf_dict_config) + + ema_default_cfg = EMAConfig().model_dump() + ema_omegaconf_dict_config = OmegaConf.create(ema_default_cfg) + cs.store(group="TRAIN/ema", name="ema_default", node=ema_omegaconf_dict_config) + + swa_default_cfg = SWAConfig().model_dump() + swa_omegaconf_dict_config = OmegaConf.create(swa_default_cfg) + cs.store(group="TRAIN/swa", name="swa_default", node=swa_omegaconf_dict_config) + + eval_default_cfg = EvalConfig().model_dump() + eval_omegaconf_dict_config = OmegaConf.create(eval_default_cfg) + cs.store(group="EVAL", name="eval_default", node=eval_omegaconf_dict_config) + + infer_default_cfg = InferConfig().model_dump() + infer_omegaconf_dict_config = OmegaConf.create(infer_default_cfg) + cs.store(group="INFER", name="infer_default", node=infer_omegaconf_dict_config) + + exclude_keys_default = [ + "mode", + "output_dir", + "log_freq", + "seed", + "use_vdl", + "use_tbd", + "wandb_config", + "use_wandb", + "device", + "use_amp", + "amp_level", + "to_static", + 
"prim", + "log_level", + "TRAIN.save_freq", + "TRAIN.eval_during_train", + "TRAIN.start_eval_epoch", + "TRAIN.eval_freq", + "TRAIN.checkpoint_path", + "TRAIN.pretrained_model_path", + "EVAL.pretrained_model_path", + "EVAL.eval_with_no_grad", + "EVAL.compute_metric_by_batch", + "INFER.pretrained_model_path", + "INFER.export_path", + "INFER.pdmodel_path", + "INFER.pdiparams_path", + "INFER.onnx_path", + "INFER.device", + "INFER.engine", + "INFER.precision", + "INFER.ir_optim", + "INFER.min_subgraph_size", + "INFER.gpu_mem", + "INFER.gpu_id", + "INFER.max_batch_size", + "INFER.num_cpu_threads", + "INFER.batch_size", + ] + cs.store( + group="hydra/job/config/override_dirname/exclude_keys", + name="exclude_keys_default", + node=exclude_keys_default, ) - - args = parser.parse_args() - return args - - -def _is_num_seq(seq): - # whether seq is all int number(it is a shape) - return isinstance(seq, (list, tuple)) and all(isinstance(x, int) for x in seq) - - -def replace_shape_with_inputspec_(node: AttrDict): - if _is_num_seq(node): - return True - - if isinstance(node, dict): - for key in node: - if replace_shape_with_inputspec_(node[key]): - node[key] = static.InputSpec(node[key]) - elif isinstance(node, list): - for i in range(len(node)): - if replace_shape_with_inputspec_(node[i]): - node[i] = static.InputSpec(node[i]) - - return False diff --git a/ppsci/utils/download.py b/ppsci/utils/download.py index 4947e4ab3e..a5f00aca62 100644 --- a/ppsci/utils/download.py +++ b/ppsci/utils/download.py @@ -157,7 +157,7 @@ def _download(url, path, md5sum=None): if chunk: f.write(chunk) shutil.move(tmp_fullname, fullname) - logger.message(f"Finish downloading pretrained model and saved to {fullname}") + logger.message(f"Finished downloading pretrained model and saved to {fullname}") return fullname diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py index 2ceb09d341..8cf368c8c4 100644 --- a/ppsci/utils/symbolic.py +++ b/ppsci/utils/symbolic.py @@ -40,7 +40,6 @@ __all__ = [ "lambdify", - "_cvt_to_key", ] @@ -117,18 +116,14 @@ def _cvt_to_key(expr: sp.Basic) -> str: Returns: str: Converted string key. """ - if isinstance(expr, sp.Function) and str(expr.func) == equation.DETACH_FUNC_NAME: - return f"{_cvt_to_key(expr.args[0])}_{equation.DETACH_FUNC_NAME}" - if isinstance(expr, (sp.Symbol, sp.core.function.UndefinedFunction, sp.Function)): - # use name of custom function(e.g. "f") instead of itself(e.g. "f(x, y)") - # for simplicity. if hasattr(expr, "name"): + # use name of custom function instead of itself. 
            return expr.name
         else:
             return str(expr)
     elif isinstance(expr, sp.Derivative):
-        # convert "Derivative(u(x,y),(x,2),(y,2))" to "u__x__x__y__y"
+        # convert Derivative(u(x,y),(x,2),(y,2)) to "u__x__x__y__y"
         expr_str = expr.args[0].name
         for symbol, order in expr.args[1:]:
             expr_str += f"__{symbol}" * order
@@ -818,13 +813,12 @@ def _expr_to_callable_nodes(
             else:
                 callable_nodes.append(OperatorNode(node))
         elif isinstance(node, sp.Function):
-            if str(node.func) == equation.DETACH_FUNC_NAME:
+            if node.name == equation.DETACH_FUNC_NAME:
                 callable_nodes.append(DetachNode(node))
-                logger.debug(f"Detected detach node {node}")
             else:
                 match_index = None
                 for j, model in enumerate(models):
-                    if str(node.func) in model.output_keys:
+                    if str(node.func.name) in model.output_keys:
                         callable_nodes.append(
                             LayerNode(
                                 node,
@@ -834,13 +828,13 @@ def _expr_to_callable_nodes(
                         if match_index is not None:
                             raise ValueError(
                                 f"Name of function: '{node}' should be unique along given"
-                                f" models, but got same output_key: '{str(node.func)}' "
+                                f" models, but got same output_key: '{node.func.name}' "
                                 f"in given models[{match_index}] and models[{j}]."
                             )
                         match_index = j
                 # NOTE: Skip 'sdf' function, which should be already generated in
                 # given data_dict
-                if match_index is None and str(node.func) != "sdf":
+                if match_index is None and node.name != "sdf":
                     raise ValueError(
                         f"Node {node} can not match any model in given model(s)."
                     )
@@ -931,7 +925,7 @@ def _expr_to_callable_nodes(
             logger.debug(
                 f"Fused {len(candidate_pos)} derivatives nodes: "
                 f"{[callable_nodes_group[i][j].expr for i, j in candidate_pos]} into"
-                f" {len(fused_node_seq)} fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
+                f" fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
             )
 
             # mark merged node
diff --git a/ppsci/validate/__init__.py b/ppsci/validate/__init__.py
index 9e05b13665..3bc1c9ae4d 100644
--- a/ppsci/validate/__init__.py
+++ b/ppsci/validate/__init__.py
@@ -33,7 +33,7 @@ def build_validator(cfg, equation_dict, geom_dict):
     """Build validator(s).
 
     Args:
-        cfg (List[AttrDict]): Validator(s) config list.
+        cfg (List[DictConfig]): Validator(s) config list.
         geom_dict (Dct[str, Geometry]): Geometry(ies) in dict.
         equation_dict (Dct[str, Equation]): Equation(s) in dict.
diff --git a/ppsci/visualize/__init__.py b/ppsci/visualize/__init__.py
index 7beea234c5..73cd0e0953 100644
--- a/ppsci/visualize/__init__.py
+++ b/ppsci/visualize/__init__.py
@@ -55,7 +55,7 @@ def build_visualizer(cfg):
     """Build visualizer(s).
 
     Args:
-        cfg (List[AttrDict]): Visualizer(s) config list.
+        cfg (List[DictConfig]): Visualizer(s) config list.
         geom_dict (Dct[str, Geometry]): Geometry(ies) in dict.
         equation_dict (Dct[str, Equation]): Equation(s) in dict.
From d319e8090664c8178b4f0420c80fd87e1ecbb880 Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Wed, 15 May 2024 06:24:06 +0000
Subject: [PATCH 2/3] update corresponding unit tests

---
 test/utils/test_config.py | 87 ++++++++++++++++-----------------
 test/utils/test_writer.py | 12 +++---
 2 files changed, 41 insertions(+), 58 deletions(-)

diff --git a/test/utils/test_config.py b/test/utils/test_config.py
index 5f650685c8..93b135d944 100644
--- a/test/utils/test_config.py
+++ b/test/utils/test_config.py
@@ -1,25 +1,17 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+import os
 
 import hydra
 import paddle
 import pytest
-from omegaconf import DictConfig
+import yaml
 
+# The callback class under test is assumed to live at this path
+from ppsci.utils.callbacks import InitCallback
 
+# Set Paddle's seed
 paddle.seed(1024)
 
+# The test function does not need the hydra.main decorator
 @pytest.mark.parametrize(
     "epochs,mode,seed",
     [
@@ -28,42 +20,35 @@
         (10, "eval", -1),
     ],
 )
-def test_invalid_epochs(
-    epochs,
-    mode,
-    seed,
-):
-    @hydra.main(version_base=None, config_path="./", config_name="test_config.yaml")
-    def main(cfg: DictConfig):
-        pass
-
-    # sys.exit will be called when validation error in pydantic, so there we use
-    # SystemExit instead of other type of errors.
-    with pytest.raises(SystemExit):
-        cfg_dict = dict(
-            {
-                "TRAIN": {
-                    "epochs": epochs,
-                },
-                "mode": mode,
-                "seed": seed,
-                "hydra": {
-                    "callbacks": {
-                        "init_callback": {
-                            "_target_": "ppsci.utils.callbacks.InitCallback"
-                        }
-                    }
-                },
+def test_invalid_epochs(tmpdir, epochs, mode, seed):
+    cfg_dict = {
+        "hydra": {
+            "callbacks": {
+                "init_callback": {"_target_": "ppsci.utils.callbacks.InitCallback"}
             }
-        )
-        # print(cfg_dict)
-        import yaml
-
-        with open("test_config.yaml", "w") as f:
-            yaml.dump(dict(cfg_dict), f)
-
-        main()
-
-
+        },
+        "mode": mode,
+        "seed": seed,
+        "TRAIN": {
+            "epochs": epochs,
+        },
+    }
+    # Create a temporary config file
+    dir_ = os.path.dirname(__file__)
+    config_abs_path = os.path.join(dir_, "test_config.yaml")
+    with open(config_abs_path, "w") as f:
+        f.write(yaml.dump(cfg_dict))
+
+    # Use hydra's compose API to build the config instead of hydra.main
+    with hydra.initialize(config_path="./", version_base=None):
+        cfg = hydra.compose(config_name="test_config.yaml")
+        # Trigger the callback manually
+        with pytest.raises(SystemExit) as exec_info:
+            InitCallback().on_job_start(config=cfg)
+        assert exec_info.value.code == 2
+        # cfg can now be asserted on or processed further as needed
+
+
+# This block is usually unnecessary, unless you want to run the tests directly as a script
 if __name__ == "__main__":
     pytest.main()
diff --git a/test/utils/test_writer.py b/test/utils/test_writer.py
index 6e960bee28..cce3f69ab8 100644
--- a/test/utils/test_writer.py
+++ b/test/utils/test_writer.py
@@ -21,13 +21,11 @@
 def test_save_csv_file():
     keys = ["x1", "y1", "z1"]
-    alias_dict = (
-        {
-            "x": "x1",
-            "y": "y1",
-            "z": "z1",
-        },
-    )
+    alias_dict = {
+        "x": "x1",
+        "y": "y1",
+        "z": "z1",
+    }
     data_dict = {
         keys[0]: np.random.randint(0, 255, (10, 1)),
         keys[1]: np.random.rand(10, 1),
From b6d713968b2ecf98e617f74b86ea5f103730861a Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Wed, 15 May 2024 06:31:06 +0000
Subject: [PATCH 3/3] update develop code

---
 ppsci/arch/phycrnet.py                  |  8 +--
 ppsci/equation/pde/base.py              | 88 ++++++++++++++++++++-----
 ppsci/equation/pde/biharmonic.py        |  2 +
 ppsci/equation/pde/heat_exchanger.py    |  2 +
 ppsci/equation/pde/laplace.py           |  2 +
 ppsci/equation/pde/linear_elasticity.py |  2 +
 ppsci/equation/pde/navier_stokes.py     |  2 +
 ppsci/equation/pde/nls_m_b.py           |  2 +
 ppsci/equation/pde/normal_dot_vec.py    |  2 +
 ppsci/equation/pde/poisson.py           |  2 +
 ppsci/equation/pde/viv.py               |  2 +
 ppsci/utils/download.py                 |  2 +-
 ppsci/utils/symbolic.py                 | 20 ++++--
 test/utils/test_config.py               | 10 +--
 14 files changed, 112 insertions(+), 34 deletions(-)
diff --git a/ppsci/arch/phycrnet.py b/ppsci/arch/phycrnet.py
index 9c15f1e0a1..9cd1fca7cf 100644
--- a/ppsci/arch/phycrnet.py
+++ b/ppsci/arch/phycrnet.py
@@ -147,7 +147,7 @@ def __init__(
         )
 
         # ConvLSTM
-        self.ConvLSTM = paddle.nn.LayerList(
+        self.convlstm = paddle.nn.LayerList(
             [
                 ConvLSTMCell(
                     input_channels=self.input_channels[i],
@@ -194,16 +194,16 @@ def forward(self, x):
             x = encoder(x)
 
         # convlstm
-        for i, LSTM in enumerate(self.ConvLSTM):
+        for i, lstm in enumerate(self.convlstm, self.num_encoder):
             if step == 0:
-                (h, c) = LSTM.init_hidden_tensor(
+                (h, c) = lstm.init_hidden_tensor(
                     prev_state=self.initial_state[i - self.num_encoder]
                 )
                 internal_state.append((h, c))
 
             # one-step forward
             (h, c) = internal_state[i - self.num_encoder]
-            x, new_c = LSTM(x, h, c)
+            x, new_c = lstm(x, h, c)
             internal_state[i - self.num_encoder] = (x, new_c)
 
         # output
diff --git a/ppsci/equation/pde/base.py b/ppsci/equation/pde/base.py
index 9ef55712a3..b5affbcf75 100644
--- a/ppsci/equation/pde/base.py
+++ b/ppsci/equation/pde/base.py
@@ -22,7 +22,7 @@
 from typing import Union
 
 import paddle
-import sympy
+import sympy as sp
 from paddle import nn
 
 DETACH_FUNC_NAME = "detach"
@@ -33,7 +33,7 @@ class PDE:
     def __init__(self):
         super().__init__()
-        self.equations = {}
+        self.equations: Dict[str, Union[Callable, sp.Basic]] = {}
 
         # for PDE which has learnable parameter(s)
         self.learnable_parameters = nn.ParameterList()
@@ -42,7 +42,7 @@ def __init__(self):
     @staticmethod
     def create_symbols(
         symbol_str: str,
-    ) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]:
+    ) -> Union[sp.Symbol, Tuple[sp.Symbol, ...]]:
         """create symbolic variables.
 
         Args:
@@ -61,11 +61,9 @@ def create_symbols(
         >>> print(symbols_xyz)
         (x, y, z)
         """
-        return sympy.symbols(symbol_str)
+        return sp.symbols(symbol_str)
 
-    def create_function(
-        self, name: str, invars: Tuple[sympy.Symbol, ...]
-    ) -> sympy.Function:
+    def create_function(self, name: str, invars: Tuple[sp.Symbol, ...]) -> sp.Function:
         """Create named function depending on given invars.
 
         Args:
@@ -86,14 +84,73 @@ def create_function(
         >>> print(f)
         f(x, y, z)
         """
-        expr = sympy.Function(name)(*invars)
+        expr = sp.Function(name)(*invars)
 
-        # wrap `expression(...)` to `detach(expression(...))`
-        # if name of expression is in given detach_keys
-        if self.detach_keys and name in self.detach_keys:
-            expr = sympy.Function(DETACH_FUNC_NAME)(expr)
         return expr
 
+    def _apply_detach(self):
+        """
+        Wrap detached sub_expr into detach(sub_expr) to prevent gradient
+        back-propagation, only for those items specified in self.detach_keys.
+
+        NOTE: This function is expected to be called after self.equations is ready in PDE.__init__.
+
+        Examples:
+            >>> import ppsci
+            >>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False)
+            >>> print(ns)
+            NavierStokes
+                continuity: Derivative(u(x, y), x) + Derivative(v(x, y), y)
+                momentum_x: u(x, y)*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
+                momentum_y: u(x, y)*Derivative(v(x, y), x) + v(x, y)*Derivative(v(x, y), y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
+            >>> detach_keys = ("u", "v__y")
+            >>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False, detach_keys=detach_keys)
+            >>> print(ns)
+            NavierStokes
+                continuity: detach(Derivative(v(x, y), y)) + Derivative(u(x, y), x)
+                momentum_x: detach(u(x, y))*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
+                momentum_y: detach(u(x, y))*Derivative(v(x, y), x) + detach(Derivative(v(x, y), y))*v(x, y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
+        """
+        if self.detach_keys is None:
+            return
+
+        from copy import deepcopy
+
+        from sympy.core.traversal import postorder_traversal
+
+        from ppsci.utils.symbolic import _cvt_to_key
+
+        for name, expr in self.equations.items():
+            if not isinstance(expr, sp.Basic):
+                continue
+            # only process sympy expression
+            expr_ = deepcopy(expr)
+            for item in postorder_traversal(expr):
+                if _cvt_to_key(item) in self.detach_keys:
+                    # wrap all related sub_expr into detach(sub_expr) in place
+                    expr_ = expr_.replace(item, sp.Function(DETACH_FUNC_NAME)(item))
+
+                    # remove all detach wrappers for more-than-once wrapped items to prevent duplicated wrapping
+                    expr_ = expr_.replace(
+                        sp.Function(DETACH_FUNC_NAME)(
+                            sp.Function(DETACH_FUNC_NAME)(item)
+                        ),
+                        sp.Function(DETACH_FUNC_NAME)(item),
+                    )
+
+                    # remove unnecessary detach wrapping for the first arg of Derivative
+                    for item_ in list(postorder_traversal(expr_)):
+                        if isinstance(item_, sp.Derivative):
+                            if item_.args[0].name == DETACH_FUNC_NAME:
+                                expr_ = expr_.replace(
+                                    item_,
+                                    sp.Derivative(
+                                        item_.args[0].args[0], *item_.args[1:]
+                                    ),
+                                )
+
+            self.equations[name] = expr_
+
     def add_equation(self, name: str, equation: Callable):
         """Add an equation.
@@ -110,7 +167,8 @@ def add_equation(self, name: str, equation: Callable): >>> equation = sympy.diff(u, x) + sympy.diff(u, y) >>> pde.add_equation('linear_pde', equation) >>> print(pde) - PDE, linear_pde: 2*x + 2*y + PDE + linear_pde: 2*x + 2*y """ self.equations.update({name: equation}) @@ -181,7 +239,7 @@ def set_state_dict( return self.learnable_parameters.set_state_dict(state_dict) def __str__(self): - return ", ".join( + return "\n".join( [self.__class__.__name__] - + [f"{name}: {eq}" for name, eq in self.equations.items()] + + [f" {name}: {eq}" for name, eq in self.equations.items()] ) diff --git a/ppsci/equation/pde/biharmonic.py b/ppsci/equation/pde/biharmonic.py index 1471c34a6c..933888ac60 100644 --- a/ppsci/equation/pde/biharmonic.py +++ b/ppsci/equation/pde/biharmonic.py @@ -70,3 +70,5 @@ def __init__( biharmonic += u.diff(invar_i, 2).diff(invar_j, 2) self.add_equation("biharmonic", biharmonic) + + self._apply_detach() diff --git a/ppsci/equation/pde/heat_exchanger.py b/ppsci/equation/pde/heat_exchanger.py index d9fd93c224..c2e0107ff3 100644 --- a/ppsci/equation/pde/heat_exchanger.py +++ b/ppsci/equation/pde/heat_exchanger.py @@ -90,3 +90,5 @@ def __init__( self.add_equation("heat_boundary", heat_boundary) self.add_equation("cold_boundary", cold_boundary) self.add_equation("wall", wall) + + self._apply_detach() diff --git a/ppsci/equation/pde/laplace.py b/ppsci/equation/pde/laplace.py index 12b2a03ddd..b99d7c8d9a 100644 --- a/ppsci/equation/pde/laplace.py +++ b/ppsci/equation/pde/laplace.py @@ -51,3 +51,5 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None): laplace += u.diff(invar, 2) self.add_equation("laplace", laplace) + + self._apply_detach() diff --git a/ppsci/equation/pde/linear_elasticity.py b/ppsci/equation/pde/linear_elasticity.py index 9120c6d21c..44833f56bf 100644 --- a/ppsci/equation/pde/linear_elasticity.py +++ b/ppsci/equation/pde/linear_elasticity.py @@ -179,3 +179,5 @@ def __init__( self.add_equation("traction_y", traction_y) if self.dim == 3: self.add_equation("traction_z", traction_z) + + self._apply_detach() diff --git a/ppsci/equation/pde/navier_stokes.py b/ppsci/equation/pde/navier_stokes.py index 41cb819bf9..c0d3d193a2 100644 --- a/ppsci/equation/pde/navier_stokes.py +++ b/ppsci/equation/pde/navier_stokes.py @@ -147,3 +147,5 @@ def __init__( self.add_equation("momentum_y", momentum_y) if self.dim == 3: self.add_equation("momentum_z", momentum_z) + + self._apply_detach() diff --git a/ppsci/equation/pde/nls_m_b.py b/ppsci/equation/pde/nls_m_b.py index 97bf60cabb..3db2984268 100644 --- a/ppsci/equation/pde/nls_m_b.py +++ b/ppsci/equation/pde/nls_m_b.py @@ -97,3 +97,5 @@ def __init__( self.add_equation("Maxwell_1", Maxwell_1) self.add_equation("Maxwell_2", Maxwell_2) self.add_equation("Bloch", Bloch) + + self._apply_detach() diff --git a/ppsci/equation/pde/normal_dot_vec.py b/ppsci/equation/pde/normal_dot_vec.py index de97a140fb..a6f3942eeb 100644 --- a/ppsci/equation/pde/normal_dot_vec.py +++ b/ppsci/equation/pde/normal_dot_vec.py @@ -55,3 +55,5 @@ def __init__( normal_dot_vec += normal * vec self.add_equation("normal_dot_vec", normal_dot_vec) + + self._apply_detach() diff --git a/ppsci/equation/pde/poisson.py b/ppsci/equation/pde/poisson.py index e83fecde05..4f9551a23a 100644 --- a/ppsci/equation/pde/poisson.py +++ b/ppsci/equation/pde/poisson.py @@ -49,3 +49,5 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None): poisson += p.diff(invar, 2) self.add_equation("poisson", poisson) + + self._apply_detach() 
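With _apply_detach() restored above, detaching once again happens as a post-order rewrite over each finished equation, so derivative outputs such as v__y can be wrapped too, not just whole functions. A minimal sympy sketch of the underlying replace step (the names are illustrative):

    import sympy as sp

    x, y = sp.symbols("x y")
    u = sp.Function("u")(x, y)
    v = sp.Function("v")(x, y)
    detach = sp.Function("detach")

    expr = u.diff(x) + v.diff(y)
    # As in _apply_detach: the matched sub-expression (here the derivative
    # keyed "v__y") is replaced by its detach(...)-wrapped form.
    expr = expr.replace(v.diff(y), detach(v.diff(y)))
    print(expr)  # detach(Derivative(v(x, y), y)) + Derivative(u(x, y), x)
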
diff --git a/ppsci/equation/pde/viv.py b/ppsci/equation/pde/viv.py index 68fd61a446..c3d85895f1 100644 --- a/ppsci/equation/pde/viv.py +++ b/ppsci/equation/pde/viv.py @@ -60,3 +60,5 @@ def __init__(self, rho: float, k1: float, k2: float): k2 = self.create_symbols(self.k2.name) f = self.rho * eta.diff(t_f, 2) + sp.exp(k1) * eta.diff(t_f) + sp.exp(k2) * eta self.add_equation("f", f) + + self._apply_detach() diff --git a/ppsci/utils/download.py b/ppsci/utils/download.py index a5f00aca62..4947e4ab3e 100644 --- a/ppsci/utils/download.py +++ b/ppsci/utils/download.py @@ -157,7 +157,7 @@ def _download(url, path, md5sum=None): if chunk: f.write(chunk) shutil.move(tmp_fullname, fullname) - logger.message(f"Finished downloading pretrained model and saved to {fullname}") + logger.message(f"Finish downloading pretrained model and saved to {fullname}") return fullname diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py index 8cf368c8c4..2ceb09d341 100644 --- a/ppsci/utils/symbolic.py +++ b/ppsci/utils/symbolic.py @@ -40,6 +40,7 @@ __all__ = [ "lambdify", + "_cvt_to_key", ] @@ -116,14 +117,18 @@ def _cvt_to_key(expr: sp.Basic) -> str: Returns: str: Converted string key. """ + if isinstance(expr, sp.Function) and str(expr.func) == equation.DETACH_FUNC_NAME: + return f"{_cvt_to_key(expr.args[0])}_{equation.DETACH_FUNC_NAME}" + if isinstance(expr, (sp.Symbol, sp.core.function.UndefinedFunction, sp.Function)): + # use name of custom function(e.g. "f") instead of itself(e.g. "f(x, y)") + # for simplicity. if hasattr(expr, "name"): - # use name of custom function instead of itself. return expr.name else: return str(expr) elif isinstance(expr, sp.Derivative): - # convert Derivative(u(x,y),(x,2),(y,2)) to "u__x__x__y__y" + # convert "Derivative(u(x,y),(x,2),(y,2))" to "u__x__x__y__y" expr_str = expr.args[0].name for symbol, order in expr.args[1:]: expr_str += f"__{symbol}" * order @@ -813,12 +818,13 @@ def _expr_to_callable_nodes( else: callable_nodes.append(OperatorNode(node)) elif isinstance(node, sp.Function): - if node.name == equation.DETACH_FUNC_NAME: + if str(node.func) == equation.DETACH_FUNC_NAME: callable_nodes.append(DetachNode(node)) + logger.debug(f"Detected detach node {node}") else: match_index = None for j, model in enumerate(models): - if str(node.func.name) in model.output_keys: + if str(node.func) in model.output_keys: callable_nodes.append( LayerNode( node, @@ -828,13 +834,13 @@ def _expr_to_callable_nodes( if match_index is not None: raise ValueError( f"Name of function: '{node}' should be unique along given" - f" models, but got same output_key: '{node.func.name}' " + f" models, but got same output_key: '{str(node.func)}' " f"in given models[{match_index}] and models[{j}]." ) match_index = j # NOTE: Skip 'sdf' function, which should be already generated in # given data_dict - if match_index is None and node.name != "sdf": + if match_index is None and str(node.func) != "sdf": raise ValueError( f"Node {node} can not match any model in given model(s)." 
                    )
@@ -925,7 +931,7 @@ def _expr_to_callable_nodes(
             logger.debug(
                 f"Fused {len(candidate_pos)} derivatives nodes: "
                 f"{[callable_nodes_group[i][j].expr for i, j in candidate_pos]} into"
-                f" fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
+                f" {len(fused_node_seq)} fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
             )
 
             # mark merged node
diff --git a/test/utils/test_config.py b/test/utils/test_config.py
index 93b135d944..844d1f449f 100644
--- a/test/utils/test_config.py
+++ b/test/utils/test_config.py
@@ -5,13 +5,11 @@
 import pytest
 import yaml
 
-# The callback class under test is assumed to live at this path
 from ppsci.utils.callbacks import InitCallback
 
-# Set Paddle's seed
 paddle.seed(1024)
 
-# The test function does not need the hydra.main decorator
+
 @pytest.mark.parametrize(
     "epochs,mode,seed",
     [
@@ -33,20 +31,18 @@ def test_invalid_epochs(tmpdir, epochs, mode, seed):
             "epochs": epochs,
         },
     }
-    # Create a temporary config file
+
     dir_ = os.path.dirname(__file__)
    config_abs_path = os.path.join(dir_, "test_config.yaml")
     with open(config_abs_path, "w") as f:
         f.write(yaml.dump(cfg_dict))
 
-    # Use hydra's compose API to build the config instead of hydra.main
     with hydra.initialize(config_path="./", version_base=None):
         cfg = hydra.compose(config_name="test_config.yaml")
-        # Trigger the callback manually
+
         with pytest.raises(SystemExit) as exec_info:
             InitCallback().on_job_start(config=cfg)
         assert exec_info.value.code == 2
-        # cfg can now be asserted on or processed further as needed
 
 
 # This block is usually unnecessary, unless you want to run the tests directly as a script
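Taken together, the series lets an example script lean on the registered Hydra default nodes and hand the whole DictConfig to Solver. A usage sketch under assumptions: conf/foo.yaml is a hypothetical config that pulls in the ppsci_default/train_default/... nodes from the defaults list shown in config.py, and constraints, optimizer, and other training inputs are omitted for brevity:

    import hydra
    from omegaconf import DictConfig

    import ppsci  # importing ppsci registers the default config nodes in ConfigStore


    @hydra.main(version_base=None, config_path="./conf", config_name="foo.yaml")
    def main(cfg: DictConfig):
        model = ppsci.arch.MLP(("x", "y"), ("u",), 5, 20)
        # With cfg given, Solver reads output_dir, epochs, AMP settings, etc.
        # via _parse_params_from_cfg instead of one-by-one keyword arguments.
        solver = ppsci.solver.Solver(model, cfg=cfg)
        solver.train()


    if __name__ == "__main__":
        main()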