diff --git a/docs/zh/examples/confild.md b/docs/zh/examples/confild.md
new file mode 100644
index 0000000000..96e167182c
--- /dev/null
+++ b/docs/zh/examples/confild.md
@@ -0,0 +1,369 @@
+# AI-Assisted Spatiotemporal Turbulence Generation: Conditional Neural Field Latent Diffusion Model (CoNFiLD)
+
+Distributed under a Creative Commons Attribution license 4.0 (CC BY).
+
+## 1. Background
+### 1.1 Paper information
+| Date | Journal | Authors | Citations | Paper PDF and supplementary material |
+|------|---------|---------|-----------|--------------------------------------|
+| January 3, 2024 | Nature Communications | Pan Du, Meet Hemant Parikh, Xiantao Fan, Xin-Yang Liu, Jian-Xun Wang | 15 | [Paper](https://doi.org/10.1038/s41467-024-54712-1)<br>[Code repository](https://github.com/jx-wang-s-group/CoNFILD) |
+
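+### 1.2 About the authors
+- **Corresponding author**: Jian-Xun Wang (王建勋)
+    - Affiliations: Department of Aerospace and Mechanical Engineering, University of Notre Dame; Department of Mechanical and Aerospace Engineering, Cornell University
+    - Research interests: turbulence modeling, generative AI, physics-informed machine learning
+
+- **Other authors**:
+    - Pan Du, Meet Hemant Parikh (co-first authors): PhD students at the University of Notre Dame, working on generative models and computational fluid dynamics
+    - Xiantao Fan, Xin-Yang Liu: research assistants at the University of Notre Dame, responsible for numerical simulation and data generation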
+
+### 1.3 Model and reproduction code
+| Problem type | Run online | Network architecture | Evaluation metric |
+|--------------|------------|----------------------|-------------------|
+| Spatiotemporal turbulence generation | [aistudio](https://aistudio.baidu.com/projectdetail/8933946) | Conditional neural field + latent diffusion model | MSE: 0.041 (velocity field) |
+
+=== "Model training command"
+```bash
+git clone https://github.com/PaddlePaddle/PaddleScience.git
+cd PaddleScience/examples/confild
+python confild.py mode=train
+```
+
+=== "Quick evaluation with a pretrained model"
+
+``` sh
+python confild.py mode=eval
+```
+
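+## 2. Problem definition
+### 2.1 Research background
+Turbulence simulation is essential in aerospace, ocean engineering, and related fields, but conventional approaches such as direct numerical simulation (DNS) and large eddy simulation (LES) are computationally expensive and hard to apply to high-Reynolds-number or real-time scenarios. Most existing deep learning models are deterministic, which makes it difficult for them to capture the chaotic nature of turbulence, and they perform poorly on complex geometries.
+
+### 2.2 Core challenges
+1. **High-dimensional data**: 3D spatiotemporal turbulence data can reach \(O(10^9)\) dimensions, so conventional generative models need enormous amounts of memory.
+2. **Stochasticity modeling**: the model must capture both the multiscale statistics and the instantaneous dynamics of turbulence.
+3. **Geometric flexibility**: irregular computational domains and adaptive meshes must be supported.
+
+### 2.3 Proposed method
+The paper proposes the **Conditional Neural Field Latent Diffusion model (CoNFiLD)**, which addresses these challenges with a three-stage framework:
+1. **Neural field encoding**: compresses the high-dimensional flow field into a low-dimensional latent representation, with compression ratios of 0.002%-0.017%.
+2. **Latent diffusion**: runs a probabilistic diffusion process in latent space to learn the statistical distribution of turbulence.
+3. **Zero-shot conditional generation**: combined with Bayesian inference, supports tasks such as reconstruction from sensor data and super-resolution without retraining.
+
+
+*Framework diagram: the CNF encoder maps the flow field to latent space, the diffusion model generates new latent samples, and the decoder reconstructs the physical field*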
+
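+## 3. Model construction
+### 3.1 Conditional neural field (CNF)
+- **Architecture**: based on the SIREN network, which uses sine activations to capture periodic features.
+- **Mathematical form**:
+  $$
+  \mathscr{E}(\mathbf{x},\mathbf{L}) = \text{SIREN}(\mathbf{x}) + \text{FiLM}(\mathbf{L})
+  $$
+  where \(\mathbf{x}\) is the query coordinate and FiLM (Feature-wise Linear Modulation) uses the latent vector \(\mathbf{L}\) to modulate the bias of each layer.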
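The bias modulation described above can be sketched for a single hidden layer in plain Python. This is a minimal illustration, not the PaddleScience implementation: the helper names `linear` and `film_sine_layer` are hypothetical, and the frequency factor `omega_0 = 30` is an assumption borrowed from the usual SIREN default rather than something stated in this document.

```python
import math
import random


def linear(weights, bias, x):
    """Dense layer; `weights` holds one row per output unit."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def film_sine_layer(w_coord, b_coord, w_latent, coords, latent, omega_0=30.0):
    """One shift-modulated sine layer: sin(omega_0 * (W x + b + V L)).

    The latent vector L enters only through the bias-free projection V L,
    which shifts the pre-activation (the effective bias) of every hidden unit.
    omega_0 = 30 is an assumed value, not taken from this document.
    """
    base = linear(w_coord, b_coord, coords)
    shift = linear(w_latent, [0.0] * len(w_latent), latent)
    return [math.sin(omega_0 * (a + s)) for a, s in zip(base, shift)]


rng = random.Random(0)
hidden, n_coord, n_latent = 4, 2, 3
w_coord = [[rng.uniform(-1, 1) for _ in range(n_coord)] for _ in range(hidden)]
b_coord = [rng.uniform(-1, 1) for _ in range(hidden)]
w_latent = [[rng.uniform(-1, 1) for _ in range(n_latent)] for _ in range(hidden)]

x = [0.25, -0.5]
out_a = film_sine_layer(w_coord, b_coord, w_latent, x, [0.0, 0.0, 0.0])  # no modulation
out_b = film_sine_layer(w_coord, b_coord, w_latent, x, [1.0, 0.0, 0.0])  # shifted by L
```

With a zero latent vector the layer reduces to a plain sine layer; a different latent vector shifts every hidden unit, which is how one shared network can represent many flow snapshots.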
+
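+### 3.2 Latent diffusion model
+- **Forward process**: Gaussian noise is added step by step, taking the latent representation \(\mathbf{z}_0 \rightarrow \mathbf{z}_T\).
+- **Reverse process**: a U-Net is trained to predict the noise, and new samples are generated by iterative denoising:
+  $$
+  \mathbf{z}_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{z}_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(\mathbf{z}_t, t) \right) + \sigma_t \epsilon
+  $$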
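One reverse step of the update above can be written out in a few lines of plain Python. This is a sketch under stated assumptions: the linear beta schedule, the choice \(\sigma_t^2 = 1-\alpha_t\), and the function names `make_schedule` and `reverse_step` are illustrative and not taken from the CoNFiLD code; the noise prediction is passed in directly instead of coming from a trained U-Net.

```python
import math
import random


def make_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule (assumed); returns alpha_t and cumulative alpha_bar_t."""
    betas = [beta_start + (beta_end - beta_start) * t / (num_steps - 1) for t in range(num_steps)]
    alphas = [1.0 - b for b in betas]
    alpha_bars, running = [], 1.0
    for a in alphas:
        running *= a
        alpha_bars.append(running)
    return alphas, alpha_bars


def reverse_step(z_t, t, eps_pred, alphas, alpha_bars, rng):
    """One ancestral sampling step z_t -> z_{t-1}.

    sigma_t^2 = 1 - alpha_t is one common choice and is assumed here.
    No noise is added at the last step (t == 0).
    """
    coef = (1.0 - alphas[t]) / math.sqrt(1.0 - alpha_bars[t])
    mean = [(z - coef * e) / math.sqrt(alphas[t]) for z, e in zip(z_t, eps_pred)]
    if t == 0:
        return mean
    sigma = math.sqrt(1.0 - alphas[t])
    return [m + sigma * rng.gauss(0.0, 1.0) for m in mean]


rng = random.Random(0)
alphas, alpha_bars = make_schedule(num_steps=10)
z0 = [0.3, -1.2, 0.8]
eps = [rng.gauss(0.0, 1.0) for _ in z0]
# Noise z0 to the first level in closed form, then undo it with the true noise.
z_noisy = [
    math.sqrt(alpha_bars[0]) * z + math.sqrt(1.0 - alpha_bars[0]) * e
    for z, e in zip(z0, eps)
]
z_rec = reverse_step(z_noisy, 0, eps, alphas, alpha_bars, rng)
```

When the predicted noise equals the noise that was actually added, the final step recovers `z0` up to rounding, which is a quick sanity check on the coefficients.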
+
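+### 3.3 Zero-shot conditional generation
+- **Bayesian posterior sampling**: given sparse observations \(\Psi\), sampling in latent space is corrected with a gradient term:
+  $$
+  \nabla_{\mathbf{z}_t} \log p(\mathbf{z}_t|\Psi) \approx \nabla_{\mathbf{z}_t} \log p(\Psi|\mathbf{z}_t) + \nabla_{\mathbf{z}_t} \log p(\mathbf{z}_t)
+  $$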
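The score decomposition above can be illustrated with a deliberately simplified setting. The sketch assumes that sensors observe selected entries of the latent vector directly, with Gaussian noise of known standard deviation; in CoNFiLD the observation instead passes through the CNF decoder and a state-to-observation map, so this shows only the structure of the sum. The names `likelihood_score` and `posterior_score` are hypothetical.

```python
def likelihood_score(z, observed, sensor_idx, noise_std):
    """Gradient w.r.t. z of log N(observed | z[sensor_idx], noise_std^2 I)."""
    grad = [0.0] * len(z)
    for value, i in zip(observed, sensor_idx):
        grad[i] = (value - z[i]) / noise_std**2
    return grad


def posterior_score(z, prior_score, observed, sensor_idx, noise_std):
    """Likelihood score plus the unconditional (prior) score."""
    like = likelihood_score(z, observed, sensor_idx, noise_std)
    return [l + p for l, p in zip(like, prior_score)]


z = [0.5, -1.0, 2.0]
prior = [-v for v in z]  # score of a standard normal prior, used as a stand-in
score = posterior_score(z, prior, observed=[1.0], sensor_idx=[2], noise_std=0.5)
print(score)  # [-0.5, 1.0, -6.0]
```

Only the observed entry receives a likelihood correction; unobserved entries follow the prior score alone, which is why the unconditional model can be reused without retraining.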
+
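+## 4. Solving the problem
+### 4.1 Dataset preparation
+The data files are organized as follows:
+```
+data              # training dataset for the CNF
+|
+|-- data.npy      # data to fit
+|
+|-- coords.npy    # query coordinates
+```
+
+After loading, the data must be normalized to make training easier. The code is as follows: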
+```python
+class Normalizer_ts(object):
+    def __init__(self, params=None, method="-11", dim=None):
+        # params holds (max, min) or (mean, std); an empty list means "not fitted yet"
+        self.params = [] if params is None else params
+        self.method = method
+        self.dim = dim
+
+    def fit_normalize(self, data):
+        assert isinstance(data, paddle.Tensor)
+        if len(self.params) == 0:
+            if self.method == "-11" or self.method == "01":
+                if self.dim is None:
+                    self.params = paddle.max(x=data), paddle.min(x=data)
+                else:
+                    self.params = (
+                        paddle.max(x=data, axis=self.dim, keepdim=True),
+                        paddle.min(x=data, axis=self.dim, keepdim=True),
+                    )
+            elif self.method == "ms":
+                if self.dim is None:
+                    self.params = paddle.mean(x=data, axis=self.dim), paddle.std(
+                        x=data, axis=self.dim
+                    )
+                else:
+                    self.params = paddle.mean(
+                        x=data, axis=self.dim, keepdim=True
+                    ), paddle.std(x=data, axis=self.dim, keepdim=True)
+            elif self.method == "none":
+                self.params = None
+        return self.fnormalize(data, self.params, self.method)
+
+    def normalize(self, new_data):
+        if not new_data.place == self.params[0].place:
+            self.params = self.params[0].to(new_data.place), self.params[1].to(
+                new_data.place
+            )
+        return self.fnormalize(new_data, self.params, self.method)
+
+    def denormalize(self, new_data_norm):
+        if not new_data_norm.place == self.params[0].place:
+            self.params = self.params[0].to(new_data_norm.place), self.params[1].to(
+                new_data_norm.place
+            )
+        return self.fdenormalize(new_data_norm, self.params, self.method)
+
+    def get_params(self):
+        if self.method == "ms":
+            print("returning mean and std")
+        elif self.method == "01":
+            print("returning max and min")
+        elif self.method == "-11":
+            print("returning max and min")
+        elif self.method == "none":
+            print("do nothing")
+        return self.params
+
+    @staticmethod
+    def fnormalize(data, params, method):
+        if method == "-11":
+            # map [min, max] to [-1, 1]
+            return (data - params[1].to(data.place)) / (
+                params[0].to(data.place) - params[1].to(data.place)
+            ) * 2 - 1
+        elif method == "01":
+            # map [min, max] to [0, 1]
+            return (data - params[1].to(data.place)) / (
+                params[0].to(data.place) - params[1].to(data.place)
+            )
+        elif method == "ms":
+            # standardize with mean and std
+            return (data - params[0].to(data.place)) / params[1].to(data.place)
+        elif method == "none":
+            return data
+
+    @staticmethod
+    def fdenormalize(data_norm, params, method):
+        if method == "-11":
+            return (data_norm + 1) / 2 * (
+                params[0].to(data_norm.place) - params[1].to(data_norm.place)
+            ) + params[1].to(data_norm.place)
+        elif method == "01":
+            return data_norm * (
+                params[0].to(data_norm.place) - params[1].to(data_norm.place)
+            ) + params[1].to(data_norm.place)
+        elif method == "ms":
+            return data_norm * params[1].to(data_norm.place) + params[0].to(
+                data_norm.place
+            )
+        elif method == "none":
+            return data_norm
+```
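The `"-11"` branch above maps the data range to \([-1, 1]\) and `fdenormalize` inverts it. A framework-free sketch of that round trip, using plain lists in place of Paddle tensors (the function names are illustrative, not part of the example code):

```python
def fit_minmax(values):
    """Return (max, min), in the same order Normalizer_ts stores them."""
    return max(values), min(values)


def normalize_m11(values, vmax, vmin):
    """Scale values from [vmin, vmax] to [-1, 1]."""
    return [(v - vmin) / (vmax - vmin) * 2 - 1 for v in values]


def denormalize_m11(values_norm, vmax, vmin):
    """Inverse of normalize_m11."""
    return [(v + 1) / 2 * (vmax - vmin) + vmin for v in values_norm]


data = [2.0, 4.0, 6.0]
vmax, vmin = fit_minmax(data)
normed = normalize_m11(data, vmax, vmin)
restored = denormalize_m11(normed, vmax, vmin)
print(normed)    # [-1.0, 0.0, 1.0]
print(restored)  # [2.0, 4.0, 6.0]
```

The statistics are fitted once and then reused, which is why the training loop can denormalize predictions with the same `out_normalizer` that normalized the targets.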
+
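+### 4.2 CoNFiLD model
+The CoNFiLD model uses Bayesian posterior sampling, taking sparse sensor measurements as the conditioning input. The trained unconditional diffusion model serves as the prior, and the uncertainty introduced by measurement noise is accounted for during diffusion posterior sampling. Using the state-to-observation map, which relates the conditioning vector to the flow field, the unconditional score function is adjusted so that generation is guided toward full spatiotemporal flow fields consistent with the sensor data. This yields the reconstruction together with an estimate of its uncertainty. The code is as follows: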
+
+```python
+class SIRENAutodecoder_film(paddle.nn.Layer):
+    """
+    SIREN network with auto-decoding.
+
+    Args:
+        input_keys (Tuple[str,...]): Keys to get the input tensors from the dict.
+        output_keys (Tuple[str,...]): Keys to save the output tensor into the dict.
+        in_coord_features (int): Number of input coordinate features.
+        in_latent_features (int): Number of input latent features.
+        out_features (int): Number of output features.
+        num_hidden_layers (int): Number of hidden layers.
+        hidden_features (int): Number of hidden features.
+        outermost_linear (bool, optional): Whether to use linear layer at the end. Defaults to False.
+        nonlinearity (str, optional): Nonlinearity to use. Defaults to "sine".
+        weight_init (Callable, optional): Weight initialization function. Defaults to None.
+        bias_init (Callable, optional): Bias initialization function. Defaults to None.
+        premap_mode (str, optional): Feature mapping mode. Defaults to None.
+
+    Examples:
+        >>> model = ppsci.arch.SIRENAutodecoder_film(
+        ...     input_keys=["input1", "input2"],
+        ...     output_keys=("output",),
+        ...     in_coord_features=2,
+        ...     in_latent_features=128,
+        ...     out_features=3,
+        ...     num_hidden_layers=10,
+        ...     hidden_features=128,
+        ... )
+        >>> input_data = {"input1": paddle.randn([10, 2]), "input2": paddle.randn([10, 128])}
+        >>> out_dict = model(input_data)
+        >>> for k, v in out_dict.items():
+        ...     print(k, v.shape)
+        output [10, 3]
+    """
+
+    def __init__(
+        self,
+        input_keys,
+        output_keys,
+        in_coord_features,
+        in_latent_features,
+        out_features,
+        num_hidden_layers,
+        hidden_features,
+        outermost_linear=False,
+        nonlinearity="sine",
+        weight_init=None,
+        bias_init=None,
+        premap_mode=None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.input_keys = input_keys
+        self.output_keys = output_keys
+
+        self.premap_mode = premap_mode
+        if self.premap_mode is not None:
+            self.premap_layer = FeatureMapping(
+                in_coord_features, mode=premap_mode, **kwargs
+            )
+            in_coord_features = self.premap_layer.dim
+        self.first_layer_init = None
+        self.nl, nl_weight_init, first_layer_init = NLS_AND_INITS[nonlinearity]
+        if weight_init is not None:
+            self.weight_init = weight_init
+        else:
+            self.weight_init = nl_weight_init
+        # net1 processes the coordinates
+        self.net1 = paddle.nn.LayerList(
+            sublayers=[BatchLinear(in_coord_features, hidden_features)]
+            + [
+                BatchLinear(hidden_features, hidden_features)
+                for i in range(num_hidden_layers)
+            ]
+            + [BatchLinear(hidden_features, out_features)]
+        )
+        # net2 projects the latent vector into a per-layer shift (bias-free)
+        self.net2 = paddle.nn.LayerList(
+            sublayers=[
+                BatchLinear(in_latent_features, hidden_features, bias_attr=False)
+                for i in range(num_hidden_layers + 1)
+            ]
+        )
+        if self.weight_init is not None:
+            self.net1.apply(self.weight_init)
+            self.net2.apply(self.weight_init)
+        if first_layer_init is not None:
+            self.net1[0].apply(first_layer_init)
+            self.net2[0].apply(first_layer_init)
+        if bias_init is not None:
+            self.net2.apply(bias_init)
+
+    def forward(self, input_data):
+        coords = input_data[self.input_keys[0]]
+        latents = input_data[self.input_keys[1]]
+        if self.premap_mode is not None:
+            x = self.premap_layer(coords)
+        else:
+            x = coords
+
+        # every layer except the last adds the latent shift before the nonlinearity
+        for i in range(len(self.net1) - 1):
+            x = self.net1[i](x) + self.net2[i](latents)
+            x = self.nl(x)
+        x = self.net1[-1](x)
+        return {self.output_keys[0]: x}
+
+    def disable_gradient(self):
+        for param in self.parameters():
+            param.stop_gradient = True
+```
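+So that specific variables can be accessed quickly and unambiguously during computation, the input keys of the network are set to `["confild_x", "latent_z"]` and the output key to `["confild_output"]`; these names are used consistently in the code that follows.
+
+### 4.3 Model training and evaluation
+With the setup above complete, the instantiated objects only need to be combined as described in this document, after which training and evaluation can be started.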
+```python
+def signal_train(cfg, normed_coords, normed_fois, spatio_axis, out_normalizer):
+    cnf_model = SIRENAutodecoder_film(**cfg.CONFILD)
+    latents_model = LatentContainer(**cfg.Latent)
+
+    dataset = basic_set(normed_fois, normed_coords)
+    criterion = paddle.nn.MSELoss()
+
+    # set loader
+    train_loader = DataLoader(
+        dataset=dataset, batch_size=cfg.TRAIN.batch_size, shuffle=True
+    )
+    test_loader = DataLoader(
+        dataset=dataset, batch_size=cfg.TRAIN.test_batch_size, shuffle=False
+    )
+    # set optimizer
+    cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model)
+    latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)(
+        latents_model
+    )
+
+    for i in range(cfg.TRAIN.epochs):
+        cnf_model.train()
+        latents_model.train()
+        # CNF weights are updated once per epoch, using the gradients
+        # accumulated over all batches of the previous epoch
+        if i != 0:
+            cnf_optimizer.step()
+            cnf_optimizer.clear_grad(set_to_zero=False)
+        train_loss = []
+        for batch_coords, batch_fois, idx in train_loader:
+            idx = {"latent_x": idx}
+            batch_latent = latents_model(idx)
+            if isinstance(batch_coords, list):
+                batch_coords = [c for c in batch_coords]
+            data = {
+                "confild_x": batch_coords,
+                "latent_z": batch_latent["latent_z"],
+            }
+            batch_output = cnf_model(data)
+            loss = criterion(batch_output["confild_output"], batch_fois)
+            # latents are updated after every batch
+            latents_optimizer.clear_grad(set_to_zero=False)
+            loss.backward()
+            latents_optimizer.step()
+            train_loss.append(loss.detach())
+        epoch_loss = paddle.stack(x=train_loss).mean().item()
+        print("epoch {}, train loss {}".format(i + 1, epoch_loss))
+        if i % 100 == 0:
+            test_error = []
+            cnf_model.eval()
+            latents_model.eval()
+            with paddle.no_grad():
+                for test_coords, test_fois, idx in test_loader:
+                    if isinstance(test_coords, list):
+                        test_coords = [c for c in test_coords]
+                    prediction = out_normalizer.denormalize(
+                        cnf_model(
+                            {
+                                "confild_x": test_coords,
+                                "latent_z": latents_model({"latent_x": idx})[
+                                    "latent_z"
+                                ],
+                            }
+                        )["confild_output"]
+                    )
+                    target = out_normalizer.denormalize(test_fois)
+                    error = rMAE(prediction=prediction, target=target, dims=spatio_axis)
+                    test_error.append(error)
+            test_error = paddle.concat(x=test_error).mean(axis=0)
+            print("test rMAE: ", test_error)
+        if i % 1000 == 0:
+            paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams")
+            paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams")
+```
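The loop above uses two update schedules: the latent vectors are stepped after every batch, while the shared network is stepped once per epoch from gradients accumulated over the whole previous epoch. The toy below reproduces only that schedule on a scalar problem (prediction `w * z_i`, squared error) with hand-written gradients. It is an illustration under simplifying assumptions (plain gradient descent instead of Adam, made-up learning rates, hypothetical function name), not a substitute for the Paddle code.

```python
def train_toy_autodecoder(targets, epochs=50, lr_decoder=0.01, lr_latent=0.1):
    w = 1.0                          # shared "decoder" weight
    latents = [0.1] * len(targets)   # one learnable latent per sample
    w_grad = 0.0                     # decoder gradient accumulated over an epoch
    history = []
    for epoch in range(epochs):
        if epoch != 0:
            # decoder step: once per epoch, from the previous epoch's gradients
            w -= lr_decoder * w_grad
            w_grad = 0.0
        epoch_loss = 0.0
        for i, y in enumerate(targets):
            residual = w * latents[i] - y
            epoch_loss += residual**2
            w_grad += 2.0 * residual * latents[i]          # accumulate only
            latents[i] -= lr_latent * 2.0 * residual * w   # latent step: immediately
        history.append(epoch_loss / len(targets))
    return w, latents, history


w, latents, history = train_toy_autodecoder([1.0, 2.0, 3.0])
print(history[0], history[-1])
```

Each sample owns its latent, so per-batch latent updates never conflict with each other, whereas the shared weight sees every sample and is updated from the summed gradient.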
+
+## 5. Experimental results
diff --git a/docs/zh/examples/confild.png b/docs/zh/examples/confild.png
new file mode 100644
index 0000000000..bafb09d9c5
Binary files /dev/null and b/docs/zh/examples/confild.png differ
diff --git a/examples/confild/conf/confild_case1.yaml b/examples/confild/conf/confild_case1.yaml
new file mode 100644
index 0000000000..4774caa4e8
--- /dev/null
+++ b/examples/confild/conf/confild_case1.yaml
@@ -0,0 +1,108 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_confild_case1
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: infer # running mode: infer
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+
+TRAIN:
+ batch_size: 64
+ test_batch_size: 256
+ epochs: 9800
+ mutil_GPU: 1
+ lr:
+ cnf: 1.e-4
+ latents: 1.e-5
+
+EVAL:
+ confild_pretrained_model_path: ./outputs_confild_case1/confild_case1/epoch_99999
+ latent_pretrained_model_path: ./outputs_confild_case1/latent_case1/epoch_99999
+
+CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 10
+ out_features: 3
+ hidden_features: 128
+ in_coord_features: 2
+ in_latent_features: 128
+
+Latent:
+ input_keys: ["latent_x"]
+ output_keys: ["latent_z"]
+ N_samples: 16000
+ lumped: True
+ N_features: 128
+ dims: 2
+
+INFER:
+ Latent:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/latent_case1
+ pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Latent.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ log_freq: 20
+ Confild:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/confild_case1
+ pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Confild.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ coord_shape: [918, 2]
+ latents_shape: [1, 128]
+ log_freq: 20
+ batch_size: 64
+
+Data:
+ data_path: /home/xinyang/store/projects/nfdiff/algo/elbow/uvp.npy
+ coor_path: /home/xinyang/store/projects/nfdiff/algo/elbow/coor.npy
+ normalizer:
+ method: "-11"
+ dim: 0
+ load_data_fn: load_elbow_flow
diff --git a/examples/confild/conf/confild_case2.yaml b/examples/confild/conf/confild_case2.yaml
new file mode 100644
index 0000000000..e41364bbf4
--- /dev/null
+++ b/examples/confild/conf/confild_case2.yaml
@@ -0,0 +1,108 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case2/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_confild_case2
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: infer # running mode: infer
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+
+TRAIN:
+ batch_size: 40
+ test_batch_size: 40
+ epochs: 44500
+ mutil_GPU: 1
+ lr:
+ cnf: 1.e-4
+ latents: 1.e-5
+
+EVAL:
+ confild_pretrained_model_path: ./outputs_confild_case2/confild_case2/epoch_99999
+ latent_pretrained_model_path: ./outputs_confild_case2/latent_case2/epoch_99999
+
+CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 10
+ out_features: 4
+ hidden_features: 256
+ in_coord_features: 2
+ in_latent_features: 256
+
+Latent:
+ input_keys: ["latent_x"]
+ output_keys: ["latent_z"]
+ N_samples: 1200
+ lumped: False
+ N_features: 256
+ dims: 2
+
+INFER:
+ Latent:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/latent_case2
+ pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Latent.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ log_freq: 20
+ Confild:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/confild_case2
+ pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Confild.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ coord_shape: [400, 100, 2]
+ latents_shape: [1, 1, 256]
+ log_freq: 20
+ batch_size: 40
+
+Data:
+ data_path: /home/xinyang/store/projects/nfdiff/algo/elbow/uvp.npy
+ coor_path: /home/xinyang/store/projects/nfdiff/algo/elbow/coor.npy
+ normalizer:
+ method: "-11"
+ dim: 0
+ load_data_fn: load_channel_flow
diff --git a/examples/confild/conf/confild_case3.yaml b/examples/confild/conf/confild_case3.yaml
new file mode 100644
index 0000000000..f15930348d
--- /dev/null
+++ b/examples/confild/conf/confild_case3.yaml
@@ -0,0 +1,108 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case3/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_confild_case3
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: infer # running mode: infer
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+
+TRAIN:
+ batch_size: 100
+ test_batch_size: 100
+ epochs: 4800
+ mutil_GPU: 2
+ lr:
+ cnf: 1.e-4
+ latents: 1.e-5
+
+EVAL:
+ confild_pretrained_model_path: ./outputs_confild_case3/confild_case3/epoch_99999
+ latent_pretrained_model_path: ./outputs_confild_case3/latent_case3/epoch_99999
+
+CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 117
+ out_features: 2
+ hidden_features: 256
+ in_coord_features: 2
+ in_latent_features: 256
+
+Latent:
+ input_keys: ["latent_x"]
+ output_keys: ["latent_z"]
+ N_samples: 2880
+ lumped: True
+ N_features: 256
+ dims: 2
+
+INFER:
+ Latent:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/latent_case3
+ pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Latent.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ log_freq: 20
+ Confild:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/confild_case3
+ pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Confild.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ coord_shape: [10884, 2]
+ latents_shape: [1, 256]
+ log_freq: 20
+ batch_size: 100
+
+Data:
+ data_path: /home/xinyang/store/projects/nfdiff/algo/elbow/uvp.npy
+ coor_path: /home/xinyang/store/projects/nfdiff/algo/elbow/coor.npy
+ normalizer:
+ method: "-11"
+ dim: 0
+ load_data_fn: load_periodic_hill_flow
diff --git a/examples/confild/conf/confild_case4.yaml b/examples/confild/conf/confild_case4.yaml
new file mode 100644
index 0000000000..68e4862768
--- /dev/null
+++ b/examples/confild/conf/confild_case4.yaml
@@ -0,0 +1,108 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case4/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_confild_case4
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: infer # running mode: infer
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+
+TRAIN:
+ batch_size: 4
+ test_batch_size: 4
+ epochs: 20000
+ mutil_GPU: 2
+ lr:
+ cnf: 1.e-4
+ latents: 1.e-5
+
+EVAL:
+ confild_pretrained_model_path: ./outputs_confild_case4/confild_case4/epoch_99999
+ latent_pretrained_model_path: ./outputs_confild_case4/latent_case4/epoch_99999
+
+CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 15
+ out_features: 3
+ hidden_features: 384
+ in_coord_features: 3
+ in_latent_features: 384
+
+Latent:
+ input_keys: ["latent_x"]
+ output_keys: ["latent_z"]
+ N_samples: 1200
+ lumped: True
+ N_features: 384
+ dims: 3
+
+INFER:
+ Latent:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/latent_case4
+ pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Latent.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ log_freq: 20
+ Confild:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/confild_case4
+ pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Confild.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ coord_shape: [58483, 3]
+ latents_shape: [1, 384]
+ log_freq: 20
+ batch_size: 4
+
+Data:
+ data_path: /home/xinyang/store/projects/nfdiff/algo/elbow/uvp.npy
+ coor_path: /home/xinyang/store/projects/nfdiff/algo/elbow/coor.npy
+ normalizer:
+ method: "-11"
+ dim: 0
+ load_data_fn: load_3d_flow
diff --git a/examples/confild/confild.py b/examples/confild/confild.py
new file mode 100644
index 0000000000..d80cb30e53
--- /dev/null
+++ b/examples/confild/confild.py
@@ -0,0 +1,570 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import paddle
+from omegaconf import DictConfig
+from paddle.distributed import fleet
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+
+import ppsci
+from ppsci.arch import LatentContainer
+from ppsci.arch import SIRENAutodecoder_film
+from ppsci.utils import logger
+
+
+def load_elbow_flow(path):
+    return np.load(path)[1:]
+
+
+def load_channel_flow(
+    path,
+    t_start=0,
+    t_end=1200,
+    t_every=1,
+):
+    return np.load(path)[t_start:t_end:t_every]
+
+
+def load_periodic_hill_flow(path):
+    data = np.load(path)
+    return data
+
+
+def load_3d_flow(path):
+    data = np.load(path)
+    return data
+
+
+def rMAE(prediction, target, dims=(1, 2)):
+    return paddle.abs(x=prediction - target).mean(axis=dims) / paddle.abs(
+        x=target
+    ).mean(axis=dims)
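Review note: `rMAE` is a relative error (mean absolute error divided by the mean absolute target), so its values are dimensionless fractions rather than errors in the units of the field. A plain-Python equivalent for a single 1D sample, with a hypothetical helper name, makes the scale explicit:

```python
def rmae(prediction, target):
    """Mean absolute error divided by the mean absolute value of the target."""
    abs_err = sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)
    abs_ref = sum(abs(t) for t in target) / len(target)
    return abs_err / abs_ref


value = rmae([1.1, 1.9, 3.2], [1.0, 2.0, 3.0])
print(round(value, 4))  # 0.0667
```

Here the absolute errors average to about 0.133 and the targets to 2.0, giving a relative error of roughly 6.7%.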
+
+
+class Normalizer_ts(object):
+    def __init__(self, params=None, method="-11", dim=None):
+        # params holds (max, min) or (mean, std); an empty list means "not fitted yet"
+        self.params = [] if params is None else params
+        self.method = method
+        self.dim = dim
+
+    def fit_normalize(self, data):
+        assert isinstance(data, paddle.Tensor)
+        if len(self.params) == 0:
+            if self.method == "-11" or self.method == "01":
+                if self.dim is None:
+                    self.params = paddle.max(x=data), paddle.min(x=data)
+                else:
+                    self.params = (
+                        paddle.max(x=data, axis=self.dim, keepdim=True),
+                        paddle.min(x=data, axis=self.dim, keepdim=True),
+                    )
+            elif self.method == "ms":
+                if self.dim is None:
+                    self.params = paddle.mean(x=data, axis=self.dim), paddle.std(
+                        x=data, axis=self.dim
+                    )
+                else:
+                    self.params = paddle.mean(
+                        x=data, axis=self.dim, keepdim=True
+                    ), paddle.std(x=data, axis=self.dim, keepdim=True)
+            elif self.method == "none":
+                self.params = None
+        return self.fnormalize(data, self.params, self.method)
+
+    def normalize(self, new_data):
+        if not new_data.place == self.params[0].place:
+            self.params = self.params[0].to(new_data.place), self.params[1].to(
+                new_data.place
+            )
+        return self.fnormalize(new_data, self.params, self.method)
+
+    def denormalize(self, new_data_norm):
+        if not new_data_norm.place == self.params[0].place:
+            self.params = self.params[0].to(new_data_norm.place), self.params[1].to(
+                new_data_norm.place
+            )
+        return self.fdenormalize(new_data_norm, self.params, self.method)
+
+    def get_params(self):
+        if self.method == "ms":
+            print("returning mean and std")
+        elif self.method == "01":
+            print("returning max and min")
+        elif self.method == "-11":
+            print("returning max and min")
+        elif self.method == "none":
+            print("do nothing")
+        return self.params
+
+    @staticmethod
+    def fnormalize(data, params, method):
+        if method == "-11":
+            # map [min, max] to [-1, 1]
+            return (data - params[1].to(data.place)) / (
+                params[0].to(data.place) - params[1].to(data.place)
+            ) * 2 - 1
+        elif method == "01":
+            # map [min, max] to [0, 1]
+            return (data - params[1].to(data.place)) / (
+                params[0].to(data.place) - params[1].to(data.place)
+            )
+        elif method == "ms":
+            # standardize with mean and std
+            return (data - params[0].to(data.place)) / params[1].to(data.place)
+        elif method == "none":
+            return data
+
+    @staticmethod
+    def fdenormalize(data_norm, params, method):
+        if method == "-11":
+            return (data_norm + 1) / 2 * (
+                params[0].to(data_norm.place) - params[1].to(data_norm.place)
+            ) + params[1].to(data_norm.place)
+        elif method == "01":
+            return data_norm * (
+                params[0].to(data_norm.place) - params[1].to(data_norm.place)
+            ) + params[1].to(data_norm.place)
+        elif method == "ms":
+            return data_norm * params[1].to(data_norm.place) + params[0].to(
+                data_norm.place
+            )
+        elif method == "none":
+            return data_norm
+
+# build data
+def getdata(cfg):
+    ###### read data - fois ######
+    if cfg.Data.load_data_fn == "load_3d_flow":
+        input_data = load_3d_flow(cfg.Data.data_path)
+    elif cfg.Data.load_data_fn == "load_elbow_flow":
+        input_data = load_elbow_flow(cfg.Data.data_path)
+    elif cfg.Data.load_data_fn == "load_channel_flow":
+        input_data = load_channel_flow(cfg.Data.data_path)
+    elif cfg.Data.load_data_fn == "load_periodic_hill_flow":
+        input_data = load_periodic_hill_flow(cfg.Data.data_path)
+    else:
+        input_data = np.load(cfg.Data.data_path)
+
+    # axis 0 is the sample (time) axis and the last axis holds the field channels
+    spatio_shape = input_data.shape[1:-1]
+    spatio_axis = list(
+        range(
+            input_data.ndim if isinstance(input_data, np.ndarray) else input_data.dim()
+        )
+    )[1:-1]
+
+    ###### read data - coordinate ######
+    if cfg.Data.coor_path is None:
+        # fall back to a uniform grid on [0, 1] along every spatial axis
+        coord = [np.linspace(0, 1, i) for i in spatio_shape]
+        coord = np.stack(np.meshgrid(*coord, indexing="ij"), axis=-1)
+    else:
+        coord = np.load(cfg.Data.coor_path)
+    coord = coord.astype("float32")
+    input_data = input_data.astype("float32")
+
+    ###### convert to tensor ######
+    input_data = (
+        paddle.to_tensor(input_data)
+        if not isinstance(input_data, paddle.Tensor)
+        else input_data
+    )
+    coord = paddle.to_tensor(coord) if not isinstance(coord, paddle.Tensor) else coord
+    N_samples = input_data.shape[0]
+
+    ###### normalizer ######
+    in_normalizer = Normalizer_ts(**cfg.Data.normalizer)
+    in_normalizer.fit_normalize(
+        coord if cfg.Latent.lumped else coord.flatten(0, cfg.Latent.dims - 1)
+    )
+    out_normalizer = Normalizer_ts(**cfg.Data.normalizer)
+    out_normalizer.fit_normalize(
+        input_data if cfg.Latent.lumped else input_data.flatten(0, cfg.Latent.dims)
+    )
+    normed_coords = in_normalizer.normalize(coord)
+    normed_fois = out_normalizer.normalize(input_data)
+
+    return normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer
+
+
+class basic_set(paddle.io.Dataset):
+    def __init__(self, fois, coord, extra_siren_in=None) -> None:
+        super().__init__()
+        self.fois = fois
+        self.total_samples = tuple(fois.shape)[0]
+        self.coords = coord
+
+    def __len__(self):
+        return self.total_samples
+
+    def __getitem__(self, idx):
+        if hasattr(self, "extra_in"):
+            extra_id = idx % tuple(self.fois.shape)[1]
+            idb = idx // tuple(self.fois.shape)[1]
+            return (self.coords, self.extra_in[extra_id]), self.fois[idb, extra_id], idx
+        else:
+            return self.coords, self.fois[idx], idx
+
+
+def signal_train(cfg, normed_coords, normed_fois, spatio_axis, out_normalizer):
+ cnf_model = SIRENAutodecoder_film(**cfg.CONFILD)
+ latents_model = LatentContainer(**cfg.Latent)
+
+ dataset = basic_set(normed_fois, normed_coords)
+ criterion = paddle.nn.MSELoss()
+
+ # set loader
+ train_loader = DataLoader(
+ dataset=dataset, batch_size=cfg.TRAIN.batch_size, shuffle=True
+ )
+ test_loader = DataLoader(
+ dataset=dataset, batch_size=cfg.TRAIN.test_batch_size, shuffle=False
+ )
+ # set optimizer
+ cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model)
+ latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)(
+ latents_model
+ )
+ losses = []
+
+ for i in range(cfg.TRAIN.epochs):
+ cnf_model.train()
+ latents_model.train()
+ if i != 0:
+ cnf_optimizer.step()
+ cnf_optimizer.clear_grad(set_to_zero=False)
+ train_loss = []
+ for batch_coords, batch_fois, idx in train_loader:
+ idx = {"latent_x": idx}
+ batch_latent = latents_model(idx)
+ if isinstance(batch_coords, list):
+ batch_coords = [i for i in batch_coords]
+ data = {
+ "confild_x": batch_coords,
+ "latent_z": batch_latent["latent_z"],
+ }
+ batch_output = cnf_model(data)
+ loss = criterion(batch_output["confild_output"], batch_fois)
+ latents_optimizer.clear_grad(set_to_zero=False)
+ loss.backward()
+ latents_optimizer.step()
+ train_loss.append(loss)
+ epoch_loss = paddle.stack(x=train_loss).mean().item()
+ losses.append(epoch_loss)
+ print("epoch {}, train loss {}".format(i + 1, epoch_loss))
+ if i % 100 == 0:
+ test_error = []
+ cnf_model.eval()
+ latents_model.eval()
+ with paddle.no_grad():
+ for test_coords, test_fois, idx in test_loader:
+ if isinstance(test_coords, list):
+ test_coords = [i for i in test_coords]
+ prediction = out_normalizer.denormalize(
+ cnf_model(
+ {
+ "confild_x": test_coords,
+ "latent_z": latents_model({"latent_x": idx})[
+ "latent_z"
+ ],
+ }
+ )["confild_output"]
+ )
+ target = out_normalizer.denormalize(test_fois)
+ error = rMAE(prediction=prediction, target=target, dims=spatio_axis)
+ test_error.append(error)
+ test_error = paddle.concat(x=test_error).mean(axis=0)
+ print("test MAE: ", test_error)
+ if i % 100 == 0:
+ paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams")
+ paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams")
+ # Plot the training-loss curve
+ plt.figure(figsize=(10, 6))
+ plt.plot(range(cfg.TRAIN.epochs), losses, label='Training Loss')
+
+ # Add the title and axis labels
+ plt.title('Training Loss over Epochs')
+ plt.xlabel('Epochs')
+ plt.xticks(rotation=45)
+ plt.ylabel('Loss')
+
+ # Add the legend
+ plt.legend()
+
+ # Show grid lines
+ plt.grid(True)
+
+ # Save as PNG
+ plt.savefig('case.png')
+
+ # Display the figure
+ plt.show()
+
+
+
+def mutil_train(cfg, normed_coords, normed_fois, spatio_axis, out_normalizer):
+ fleet.init(is_collective=True)
+ cnf_model = SIRENAutodecoder_film(**cfg.CONFILD)
+ cnf_model = fleet.distributed_model(cnf_model)
+ latents_model = LatentContainer(**cfg.Latent)
+ latents_model = fleet.distributed_model(latents_model)
+
+ # set optimizer
+ cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model)
+ cnf_optimizer = fleet.distributed_optimizer(cnf_optimizer)
+ latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)(
+ latents_model
+ )
+ latents_optimizer = fleet.distributed_optimizer(latents_optimizer)
+
+ dataset = basic_set(normed_fois, normed_coords)
+
+ train_sampler = DistributedBatchSampler(
+ dataset, cfg.TRAIN.batch_size, shuffle=True, drop_last=True
+ )
+ train_loader = DataLoader(
+ dataset,
+ batch_sampler=train_sampler,
+ # shuffling is handled by train_sampler; DataLoader rejects shuffle=True when batch_sampler is given
+ num_workers=cfg.TRAIN.mutil_GPU,
+ )
+ test_sampler = DistributedBatchSampler(
+ dataset, cfg.TRAIN.test_batch_size, drop_last=True
+ )
+ test_loader = DataLoader(
+ dataset,
+ batch_sampler=test_sampler,
+ shuffle=False,
+ num_workers=cfg.TRAIN.mutil_GPU,
+ )
+
+ criterion = paddle.nn.MSELoss()
+ losses = []
+
+ for i in range(cfg.TRAIN.epochs):
+ cnf_model.train()
+ latents_model.train()
+ if i != 0:
+ cnf_optimizer.step()
+ cnf_optimizer.clear_grad(set_to_zero=False)
+ train_loss = []
+ for batch_coords, batch_fois, idx in train_loader:
+ idx = {"latent_x": idx}
+ batch_latent = latents_model(idx)
+ if isinstance(batch_coords, list):
+ batch_coords = [i for i in batch_coords]
+ data = {
+ "confild_x": batch_coords,
+ "latent_z": batch_latent["latent_z"],
+ }
+ batch_output = cnf_model(data)
+ loss = criterion(batch_output["confild_output"], batch_fois)
+ latents_optimizer.clear_grad(set_to_zero=False)
+ loss.backward()
+ latents_optimizer.step()
+ train_loss.append(loss)
+ epoch_loss = paddle.stack(x=train_loss).mean().item()
+ losses.append(epoch_loss)
+ print("epoch {}, train loss {}".format(i + 1, epoch_loss))
+ if i % 100 == 0:
+ test_error = []
+ cnf_model.eval()
+ latents_model.eval()
+ with paddle.no_grad():
+ for test_coords, test_fois, idx in test_loader:
+ if isinstance(test_coords, list):
+ test_coords = [i for i in test_coords]
+ prediction = out_normalizer.denormalize(
+ cnf_model(
+ {
+ "confild_x": test_coords,
+ "latent_z": latents_model({"latent_x": idx})[
+ "latent_z"
+ ],
+ }
+ )["confild_output"]
+ )
+ target = out_normalizer.denormalize(test_fois)
+ error = rMAE(prediction=prediction, target=target, dims=spatio_axis)
+ test_error.append(error)
+ test_error = paddle.concat(x=test_error).mean(axis=0)
+ print("test MAE: ", test_error)
+ if i % 100 == 0:
+ paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams")
+ paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams")
+ # Plot the training-loss curve
+ plt.figure(figsize=(10, 6))
+ plt.plot(range(cfg.TRAIN.epochs), losses, label='Training Loss')
+
+ # Add the title and axis labels
+ plt.title('Training Loss over Epochs')
+ plt.xlabel('Epochs')
+ plt.xticks(rotation=45)
+ plt.ylabel('Loss')
+
+ # Add the legend
+ plt.legend()
+
+ # Show grid lines
+ plt.grid(True)
+
+ # Save as PNG
+ plt.savefig('case.png')
+
+ # Display the figure
+ plt.show()
+
+
+def train(cfg):
+ normed_coords, normed_fois, _, spatio_axis, out_normalizer = getdata(cfg)
+ if cfg.TRAIN.mutil_GPU > 1:
+ mutil_train(cfg, normed_coords, normed_fois, spatio_axis, out_normalizer)
+ else:
+ signal_train(cfg, normed_coords, normed_fois, spatio_axis, out_normalizer)
+
+
+def evaluate(cfg: DictConfig):
+ # set data
+ normed_coords, normed_fois, _, spatio_axis, out_normalizer = getdata(cfg)
+
+ if len(normed_coords.shape) + 1 == len(normed_fois.shape):
+ normed_coords = paddle.tile(
+ normed_coords, [normed_fois.shape[0]] + [1] * len(normed_coords.shape)
+ )
+
+ idx = paddle.to_tensor(
+ np.array([i for i in range(normed_fois.shape[0])]), dtype="int64"
+ )
+ # set model
+ confild = SIRENAutodecoder_film(**cfg.CONFILD)
+ latent = LatentContainer(**cfg.Latent)
+ logger.info(
+ "Loading pretrained model from {}".format(
+ cfg.EVAL.confild_pretrained_model_path
+ )
+ )
+ ppsci.utils.save_load.load_pretrain(
+ confild,
+ cfg.EVAL.confild_pretrained_model_path,
+ )
+ logger.info(
+ "Loading pretrained model from {}".format(cfg.EVAL.latent_pretrained_model_path)
+ )
+ ppsci.utils.save_load.load_pretrain(
+ latent,
+ cfg.EVAL.latent_pretrained_model_path,
+ )
+ latent_test_pred = latent({"latent_x": idx})
+ y_test_pred = []
+ for i in range(normed_coords.shape[0]):
+ y_test_pred.append(
+ confild(
+ {
+ "confild_x": normed_coords[i],
+ "latent_z": latent_test_pred["latent_z"][i],
+ }
+ )["confild_output"].numpy()
+ )
+ y_test_pred = paddle.to_tensor(np.array(y_test_pred))
+
+ y_test_pred = out_normalizer.denormalize(y_test_pred)
+ y_test = out_normalizer.denormalize(normed_fois)
+ logger.info("Result is {}".format(y_test_pred.numpy()))
+
+
+def inference(cfg):
+ normed_coords, normed_fois, _, _, _ = getdata(cfg)
+ if len(normed_coords.shape) + 1 == len(normed_fois.shape):
+ normed_coords = paddle.tile(
+ normed_coords, [normed_fois.shape[0]] + [1] * len(normed_coords.shape)
+ )
+
+ fois_len = normed_fois.shape[0]
+ idxs = np.array([i for i in range(fois_len)])
+ from deploy import python_infer
+
+ latent_predictor = python_infer.GeneralPredictor(cfg.INFER.Latent)
+ input_dict = {"latent_x": idxs}
+ output_dict = latent_predictor.predict(input_dict, cfg.INFER.batch_size)
+
+ cnf_predictor = python_infer.GeneralPredictor(cfg.INFER.Confild)
+ input_dict = {
+ "confild_x": normed_coords.numpy(),
+ "latent_z": list(output_dict.values())[0],
+ }
+ output_dict = cnf_predictor.predict(input_dict, cfg.INFER.batch_size)
+
+ logger.info("Result is {}".format(output_dict["confild_output"]))
+
+
+def export(cfg):
+ # set model
+ cnf_model = SIRENAutodecoder_film(**cfg.CONFILD)
+ latent_model = LatentContainer(**cfg.Latent)
+ # initialize solver
+ latnet_solver = ppsci.solver.Solver(
+ latent_model,
+ pretrained_model_path=cfg.INFER.Latent.INFER.pretrained_model_path,
+ )
+ cnf_solver = ppsci.solver.Solver(
+ cnf_model,
+ pretrained_model_path=cfg.INFER.Confild.INFER.pretrained_model_path,
+ )
+ # export model
+ from paddle.static import InputSpec
+
+ input_spec = [
+ {key: InputSpec([None], "int64", name=key) for key in latent_model.input_keys},
+ ]
+ cnf_input_spec = [
+ {
+ cnf_model.input_keys[0]: InputSpec(
+ [None] + list(cfg.INFER.Confild.INFER.coord_shape),
+ "float32",
+ name=cnf_model.input_keys[0],
+ ),
+ cnf_model.input_keys[1]: InputSpec(
+ [None] + list(cfg.INFER.Confild.INFER.latents_shape),
+ "float32",
+ name=cnf_model.input_keys[1],
+ ),
+ }
+ ]
+ cnf_solver.export(cnf_input_spec, cfg.INFER.Confild.INFER.export_path)
+ latnet_solver.export(input_spec, cfg.INFER.Latent.INFER.export_path)
+
+
+@hydra.main(version_base=None, config_path="./conf", config_name="confild_case1.yaml")
+def main(cfg: DictConfig):
+ if cfg.mode == "train":
+ train(cfg)
+ elif cfg.mode == "eval":
+ evaluate(cfg)
+ elif cfg.mode == "infer":
+ inference(cfg)
+ elif cfg.mode == "export":
+ export(cfg)
+ else:
+ raise ValueError(
+ f"cfg.mode should be in ['train', 'eval', 'infer', 'export'], but got '{cfg.mode}'"
+ )
+
+
+if __name__ == "__main__":
+ main()
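Note on `basic_set.__getitem__` in the file above: when `extra_in` is present, the flat sample index is split into a snapshot index (`idb`) and an extra-input index (`extra_id`) using integer division and modulo by `fois.shape[1]`. The following is a minimal, framework-free sketch of that mapping. The helper name `split_flat_index` and the sizes are hypothetical, chosen only for illustration, and are not part of this PR.

```python
def split_flat_index(idx, n_extra):
    """Split a flat sample index into (snapshot index, extra-input index).

    Mirrors the arithmetic in basic_set.__getitem__, where n_extra plays
    the role of fois.shape[1] (assumption: one sample per pair).
    """
    extra_id = idx % n_extra   # position along the extra-input axis
    idb = idx // n_extra       # position along the snapshot axis
    return idb, extra_id


# Hypothetical sizes, for illustration only.
n_snapshots, n_extra = 3, 4
pairs = [split_flat_index(i, n_extra) for i in range(n_snapshots * n_extra)]

for flat, (snap, extra) in enumerate(pairs):
    print(flat, "->", (snap, extra))
```

Each flat index maps to exactly one `(snapshot, extra)` pair, so `total_samples` would need to equal the product of the two axis lengths for the dataset to cover every pair once.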
diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py
index ec3f63597c..1b108d3446 100644
--- a/ppsci/arch/__init__.py
+++ b/ppsci/arch/__init__.py
@@ -22,6 +22,7 @@
from ppsci.arch.base import Arch # isort:skip
from ppsci.arch.cfdgcn import CFDGCN # isort:skip
from ppsci.arch.chip_deeponets import ChipDeepONets # isort:skip
+from ppsci.arch.confild import LatentContainer, SIRENAutodecoder_film # isort:skip
from ppsci.arch.crystalgraphconvnet import CrystalGraphConvNet # isort:skip
from ppsci.arch.cuboid_transformer import CuboidTransformer # isort:skip
from ppsci.arch.cvit import CVit # isort:skip
@@ -88,6 +89,7 @@
"Generator",
"GraphCastNet",
"HEDeepONets",
+ "LatentContainer",
"LorenzEmbedding",
"LNO",
"MLP",
@@ -100,6 +102,7 @@
"PrecipNet",
"RosslerEmbedding",
"SFNONet",
+ "SIRENAutodecoder_film",
"SPINN",
"TFNO1dNet",
"TFNO2dNet",
diff --git a/ppsci/arch/confild.py b/ppsci/arch/confild.py
new file mode 100644
index 0000000000..e68e8cdfd7
--- /dev/null
+++ b/ppsci/arch/confild.py
@@ -0,0 +1,376 @@
+import math
+from collections import OrderedDict
+
+import numpy as np
+import paddle
+
+DEFAULT_W0 = 30.0
+
+
+class Swish(paddle.nn.Layer):
+ def __init__(self):
+ super().__init__()
+ self.Sigmoid = paddle.nn.Sigmoid()
+
+ def forward(self, x):
+ return x * self.Sigmoid(x)
+
+
+class Sine(paddle.nn.Layer):
+ def __init__(self, w0=DEFAULT_W0):
+ super().__init__()
+ self.w0 = w0
+
+ def forward(self, input):
+ return paddle.sin(x=self.w0 * input)
+
+
+def sine_init(m, w0=DEFAULT_W0):
+ with paddle.no_grad():
+ if hasattr(m, "weight"):
+ num_input = m.weight.shape[0]  # paddle.nn.Linear weight is [in_features, out_features]
+ m.weight.uniform_(
+ min=-math.sqrt(6 / num_input) / w0, max=math.sqrt(6 / num_input) / w0
+ )
+
+
+def first_layer_sine_init(m):
+ with paddle.no_grad():
+ if hasattr(m, "weight"):
+ num_input = m.weight.shape[0]  # paddle.nn.Linear weight is [in_features, out_features]
+ m.weight.uniform_(min=-1 / num_input, max=1 / num_input)
+
+
+def __check_Linear_weight(m):
+ if isinstance(m, paddle.nn.Linear):
+ if hasattr(m, "weight"):
+ return True
+ return False
+
+
+def init_weights_normal(m):
+ if __check_Linear_weight(m):
+ init_KaimingNormal = paddle.nn.initializer.KaimingNormal(
+ nonlinearity="relu", negative_slope=0.0
+ )
+ init_KaimingNormal(m.weight)
+
+
+def init_weights_selu(m):
+ if __check_Linear_weight(m):
+ num_input = m.weight.shape[0]  # paddle.nn.Linear weight is [in_features, out_features]
+ init_Normal = paddle.nn.initializer.Normal(std=1 / math.sqrt(num_input))
+ init_Normal(m.weight)
+
+
+def init_weights_elu(m):
+ if __check_Linear_weight(m):
+ num_input = m.weight.shape[0]  # paddle.nn.Linear weight is [in_features, out_features]
+ init_Normal = paddle.nn.initializer.Normal(
+ std=math.sqrt(1.5505188080679277) / math.sqrt(num_input)
+ )
+ init_Normal(m.weight)
+
+
+def init_weights_xavier(m):
+ if __check_Linear_weight(m):
+ init_XavierNormal = paddle.nn.initializer.XavierNormal()
+ init_XavierNormal(m.weight)
+
+
+NLS_AND_INITS = {
+ "sine": (Sine(), sine_init, first_layer_sine_init),
+ "relu": (paddle.nn.ReLU(), init_weights_normal, None),
+ "sigmoid": (paddle.nn.Sigmoid(), init_weights_xavier, None),
+ "tanh": (paddle.nn.Tanh(), init_weights_xavier, None),
+ "selu": (paddle.nn.SELU(), init_weights_selu, None),
+ "softplus": (paddle.nn.Softplus(), init_weights_normal, None),
+ "elu": (paddle.nn.ELU(), init_weights_elu, None),
+ "swish": (Swish(), init_weights_xavier, None),
+}
+
+
+class BatchLinear(paddle.nn.Linear):
+ """
+ This is a linear transformation implemented manually. It also allows parameters to be passed in manually.
+ For initialization, (in_features, out_features) needs to be provided.
+ weight is of shape (in_features, out_features)
+ bias is of shape (out_features)
+
+ """
+
+ __doc__ = paddle.nn.Linear.__doc__
+
+ def forward(self, input, params=None):
+ if params is None:
+ params = OrderedDict(self.named_parameters())
+ bias = params.get("bias", None)
+ weight = params["weight"]
+
+ output = paddle.matmul(x=input, y=weight)
+ if bias is not None:
+ output += bias.unsqueeze(axis=-2)
+ return output
+
+
+class FeatureMapping:
+ """
+ This is the feature-mapping class for Fourier feature networks.
+ """
+
+ def __init__(
+ self,
+ in_features,
+ mode="basic",
+ gaussian_mapping_size=256,
+ gaussian_rand_key=0,
+ gaussian_tau=1.0,
+ pe_num_freqs=4,
+ pe_scale=2,
+ pe_init_scale=1,
+ pe_use_nyquist=True,
+ pe_lowest_dim=None,
+ rbf_out_features=None,
+ rbf_range=1.0,
+ rbf_std=0.5,
+ ):
+ """
+ inputs:
+ in_features: number of input features
+ gaussian_mapping_size: number of output features for Gaussian mapping
+ gaussian_rand_key: random seed for Gaussian mapping
+ gaussian_tau: standard deviation for Gaussian mapping
+ pe_num_freqs: number of frequencies for positional encoding (P.E.)
+ pe_scale: base scale of frequencies for P.E. (default 2)
+ pe_init_scale: initial scale for P.E.
+ pe_use_nyquist: whether to derive pe_num_freqs from the Nyquist rate
+
+ """
+ self.mode = mode
+ if mode == "basic":
+ self.B = np.eye(in_features)
+ elif mode == "gaussian":
+ rng = np.random.default_rng(gaussian_rand_key)
+ self.B = rng.normal(
+ loc=0.0, scale=gaussian_tau, size=(gaussian_mapping_size, in_features)
+ )
+ elif mode == "positional":
+ if str(pe_use_nyquist) == "True" and pe_lowest_dim:
+ pe_num_freqs = self.get_num_frequencies_nyquist(pe_lowest_dim)
+ self.B = pe_init_scale * np.vstack(
+ [(pe_scale**i * np.eye(in_features)) for i in range(pe_num_freqs)]
+ )
+ self.dim = tuple(self.B.shape)[0] * 2
+ elif mode == "rbf":
+ self.centers = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=paddle.empty(
+ shape=(rbf_out_features, in_features), dtype="float32"
+ )
+ )
+ self.sigmas = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=paddle.empty(shape=rbf_out_features, dtype="float32")
+ )
+ init_Uniform = paddle.nn.initializer.Uniform(
+ low=-1 * rbf_range, high=rbf_range
+ )
+ init_Uniform(self.centers)
+ init_Constant = paddle.nn.initializer.Constant(value=rbf_std)
+ init_Constant(self.sigmas)
+
+ def __call__(self, input):
+ if self.mode in ["basic", "gaussian", "positional"]:
+ return self.fourier_mapping(input, self.B)
+ elif self.mode == "rbf":
+ return self.rbf_mapping(input)
+
+ def get_num_frequencies_nyquist(self, samples):
+ nyquist_rate = 1 / (2 * (2 * 1 / samples))
+ return int(math.floor(math.log(nyquist_rate, 2)))
+
+ @staticmethod
+ def fourier_mapping(x, B):
+ """
+ x is the input, B is the reference information
+ """
+ if B is None:
+ return x
+ else:
+ B = paddle.to_tensor(data=B, dtype="float32", place=x.place)
+ x_proj = 2.0 * np.pi * x @ B.T
+ return paddle.concat(
+ x=[paddle.sin(x=x_proj), paddle.cos(x=x_proj)], axis=-1
+ )
+
+ def rbf_mapping(self, x):
+ size = tuple(x.shape)[:-1] + tuple(self.centers.shape)
+ x = x.unsqueeze(axis=-2).expand(shape=size)
+ distances = (x - self.centers).pow(y=2).sum(axis=-1) * self.sigmas
+ return self.gaussian(distances)
+
+ @staticmethod
+ def gaussian(alpha):
+ phi = paddle.exp(x=-1 * alpha.pow(y=2))
+ return phi
+
+
+class SIRENAutodecoder_film(paddle.nn.Layer):
+ """
+ SIREN network with auto-decoding
+
+ Args:
+ input_keys (Tuple[str,...], optional): Key to get the input tensor from the dict.
+ output_keys (Tuple[str,...], optional): Key to save the output tensor into the dict.
+ in_coord_features (int, optional): Number of input coordinates features
+ in_latent_features (int, optional): Number of input latent features
+ out_features (int, optional): Number of output features
+ num_hidden_layers (int, optional): Number of hidden layers
+ hidden_features (int, optional): Number of hidden features
+ outermost_linear (bool, optional): Whether to use linear layer at the end. Defaults to False.
+ nonlinearity (str, optional): Nonlinearity to use. Defaults to "sine".
+ weight_init (Callable, optional): Weight initialization function. Defaults to None.
+ bias_init (Callable, optional): Bias initialization function. Defaults to None.
+ premap_mode (str, optional): Feature mapping mode. Defaults to None.
+
+ Examples:
+ >>> model = ppsci.arch.SIRENAutodecoder_film(
+ ... input_keys=["input1", "input2"],
+ ... output_keys=("output",),
+ ... in_coord_features=2,
+ ... in_latent_features=128,
+ ... out_features=3,
+ ... num_hidden_layers=10,
+ ... hidden_features=128,
+ ... )
+ >>> input_data = {"input1": paddle.randn([10, 2]), "input2": paddle.randn([10, 128])}
+ >>> out_dict = model(input_data)
+ >>> for k, v in out_dict.items():
+ ... print(k, v.shape)
+ output [10, 3]
+ """
+
+ def __init__(
+ self,
+ input_keys,
+ output_keys,
+ in_coord_features,
+ in_latent_features,
+ out_features,
+ num_hidden_layers,
+ hidden_features,
+ outermost_linear=False,
+ nonlinearity="sine",
+ weight_init=None,
+ bias_init=None,
+ premap_mode=None,
+ **kwargs,
+ ):
+ super().__init__()
+ self.input_keys = input_keys
+ self.output_keys = output_keys
+
+ self.premap_mode = premap_mode
+ if self.premap_mode is not None:
+ self.premap_layer = FeatureMapping(
+ in_coord_features, mode=premap_mode, **kwargs
+ )
+ in_coord_features = self.premap_layer.dim
+ self.first_layer_init = None
+ self.nl, nl_weight_init, first_layer_init = NLS_AND_INITS[nonlinearity]
+ if weight_init is not None:
+ self.weight_init = weight_init
+ else:
+ self.weight_init = nl_weight_init
+ self.net1 = paddle.nn.LayerList(
+ sublayers=[BatchLinear(in_coord_features, hidden_features)]
+ + [
+ BatchLinear(hidden_features, hidden_features)
+ for i in range(num_hidden_layers)
+ ]
+ + [BatchLinear(hidden_features, out_features)]
+ )
+ self.net2 = paddle.nn.LayerList(
+ sublayers=[
+ BatchLinear(in_latent_features, hidden_features, bias_attr=False)
+ for i in range(num_hidden_layers + 1)
+ ]
+ )
+ if self.weight_init is not None:
+ self.net1.apply(self.weight_init)
+ self.net2.apply(self.weight_init)
+ if first_layer_init is not None:
+ self.net1[0].apply(first_layer_init)
+ self.net2[0].apply(first_layer_init)
+ if bias_init is not None:
+ self.net2.apply(bias_init)
+
+ def forward(self, input_data):
+ coords = input_data[self.input_keys[0]]
+ latents = input_data[self.input_keys[1]]
+ if self.premap_mode is not None:
+ x = self.premap_layer(coords)
+ else:
+ x = coords
+
+ for i in range(len(self.net1) - 1):
+ x = self.net1[i](x) + self.net2[i](latents)
+ x = self.nl(x)
+ x = self.net1[-1](x)
+ return {self.output_keys[0]: x}
+
+ def disable_gradient(self):
+ for param in self.parameters():
+ param.stop_gradient = True
+
+
+class LatentContainer(paddle.nn.Layer):
+ """
+ A model container that stores latents for multi-GPU training.
+
+ Args:
+ input_keys (Tuple[str, ...], optional): Key to get the input tensor from the dict. Defaults to ("input",).
+ output_keys (Tuple[str, ...], optional): Key to save the output tensor into the dict. Defaults to ("output",).
+ N_samples (int, optional): Number of samples. Defaults to None.
+ N_features (int, optional): Number of features. Defaults to None.
+ dims (int, optional): Number of dimensions. Defaults to None.
+ lumped (bool, optional): Whether to lump the latents. Defaults to False.
+
+ Examples:
+ >>> model = ppsci.arch.LatentContainer(N_samples=1600, N_features=128, dims=2, lumped=True)
+ >>> input_data = paddle.arange(0, 1600, dtype='int64')
+ >>> input_dict = {"input": input_data}
+ >>> out_dict = model(input_dict)
+ >>> for k, v in out_dict.items():
+ ... print(k, v.shape)
+ output [1600, 1, 128]
+ """
+
+ def __init__(
+ self,
+ input_keys=("input",),
+ output_keys=("output",),
+ N_samples=None,
+ N_features=None,
+ dims=None,
+ lumped=False,
+ ):
+ super().__init__()
+ self.input_keys = input_keys
+ self.output_keys = output_keys
+ self.dims = [1] * dims if not lumped else [1]
+ self.expand_dims = " ".join(["1" for _ in range(dims)]) if not lumped else "1"
+ self.expand_dims = f"N f -> N {self.expand_dims} f"
+ self.latents = self.create_parameter(
+ shape=(N_samples, N_features),
+ dtype="float32",
+ default_initializer=paddle.nn.initializer.Constant(0.0),
+ )
+
+ def forward(self, batch_ids):
+ x = batch_ids[self.input_keys[0]]
+ selected_latents = paddle.gather(self.latents, x)
+ if len(selected_latents.shape) > 1:
+ getShape = [tuple(selected_latents.shape)[0]] + self.dims + [tuple(selected_latents.shape)[1]]
+ else:
+ getShape = [-1] + self.dims
+ expanded_latents = selected_latents.reshape(getShape)
+ return {self.output_keys[0]: expanded_latents}
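Note on the initializers and `get_num_frequencies_nyquist` in `ppsci/arch/confild.py`: the sine initializers draw weights uniformly from a range set by the layer's fan-in, and the positional-encoding frequency count comes from a Nyquist-rate formula. The sketch below reproduces only that arithmetic in plain Python so the bounds can be checked without Paddle. The helper names are hypothetical, and passing the fan-in explicitly is an assumption made for illustration (in the PR it is read from the weight shape).

```python
import math

DEFAULT_W0 = 30.0  # same constant as in ppsci/arch/confild.py


def siren_hidden_bound(fan_in, w0=DEFAULT_W0):
    """Uniform bound used by sine_init: sqrt(6 / fan_in) / w0."""
    return math.sqrt(6 / fan_in) / w0


def siren_first_layer_bound(fan_in):
    """Uniform bound used by first_layer_sine_init: 1 / fan_in."""
    return 1 / fan_in


def num_frequencies_nyquist(samples):
    """Same formula as FeatureMapping.get_num_frequencies_nyquist."""
    nyquist_rate = 1 / (2 * (2 * 1 / samples))
    return int(math.floor(math.log(nyquist_rate, 2)))


# Hypothetical layer sizes, for illustration only.
print("hidden bound, fan_in=128:", siren_hidden_bound(128))
print("first-layer bound, fan_in=2:", siren_first_layer_bound(2))
print("P.E. frequencies for 100 samples:", num_frequencies_nyquist(100))
```

Because the bounds depend on fan-in, reading the wrong axis of the weight matrix changes the initial scale whenever a layer's input and output widths differ, which is the case for the first and last layers here.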