diff --git a/jointContribution/yinglong/README.md b/jointContribution/yinglong/README.md new file mode 100644 index 0000000000..35a9f05995 --- /dev/null +++ b/jointContribution/yinglong/README.md @@ -0,0 +1,77 @@ +# A Study of Data-driven Limited Area Model for Weather Forecasting + +Recently, artificial intelligence-based models for forecasting global weather have been rapidly developed. Most of the global models are trained on reanalysis datasets with a spatial resolution of 0.25◦ × 0.25◦. However, study on artificial intelligence-based limited area weather forecasting models is still limited. In this study, an artificial intelligence-based limited area weather forecasting model (YingLong) is developed. YingLong utilizes a parallel structure of global and local blocks to capture multiscale meteorological features. Its predictability on surface temperature, humidity and wind speed is comparable to the predictability of the dynamical limited area model WRF-ARW, but with a much faster running speed. YingLong is also applied to investigate the issues related to the lateral boundary condition of artificial intelligence-based limited area models. The difference between artificial intelligence-based limited area models and dynamical limited area models is also discussed. + +This code is the implementation of YingLong. We select the southeastern region of the United States, which is around the range of 80-110W, 30-42N, with 440 × 408 grid points in Lambert projection. + +
+ +
+ +## Installation + +### 1. Install PaddlePaddle + +Please install the 2.6.0 or develop version of PaddlePaddle according to your environment on the official website of [PaddlePaddle](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html). + +For example, if your environment is linux and CUDA 11.2, you can install PaddlePaddle by the following command. + +``` shell +python -m pip install paddlepaddle-gpu==2.6.0.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html +``` + +After installation, run the following command to verify if PaddlePaddle has been successfully installed. + +``` shell +python -c "import paddle; paddle.utils.run_check()" +``` + +If `"PaddlePaddle is installed successfully! Let's start deep learning with PaddlePaddle now."` appears, to verify that the installation was successful. + +### 2. Install PaddleScience + +Clone the code of PaddleScience from [here](https://github.com/PaddlePaddle/PaddleScience.git) and install requirements. + +``` shell +git clone -b develop https://github.com/PaddlePaddle/PaddleScience.git +cd PaddleScience +pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple +export PYTHONPATH=$PWD +``` + +## Example Usage + +### 1. Download the data and model weights + +``` shell +cd examples/yinglong +wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/yinglong/western_valid_data.tar +tar -xvf western_valid_data.tar +wget https://paddle-org.bj.bcebos.com/paddlescience/datasets/yinglong/eastern_valid_data.tar +tar -xvf eastern_valid_data.tar +wget https://paddle-org.bj.bcebos.com/paddlescience/models/yinglong/inference.tar +tar -xvf inference.tar +``` + +### 2. Run the code + +The following code runs the YingLong model, and the model output will be saved in `outputs_yinglong_eastern(western)/result.npy`. + +``` shell +model pretrain +python -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1' /examples/train_pretrain_parallel.py + +model finetune +python -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1' /examples/train_finetune_parallel.py + +model inference +python /examples/inference.py +``` + +We also visualized the predicted wind speed at 10 meters above ground level, with an initial field of 0:00 on January 1, 2022. Click [eastern](https://paddle-org.bj.bcebos.com/paddlescience/docs/Yinglong/result_eastern.gif)/[western](https://paddle-org.bj.bcebos.com/paddlescience/docs/Yinglong/result_western.gif) to view the prediction results. + +## License + +YingLong was released by Shanghai Zhangjiang Institute of Mathematics, Baidu inc. + +The commercial use of these models is forbidden. diff --git a/jointContribution/yinglong/examples/inference.py b/jointContribution/yinglong/examples/inference.py new file mode 100644 index 0000000000..cb2b2f7cdc --- /dev/null +++ b/jointContribution/yinglong/examples/inference.py @@ -0,0 +1,223 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shutil +from typing import Tuple + +import h5py +import numpy as np +import paddle.distributed as dist +import utils as local_utils +import visualdl as vdl + +import ppsci +from ppsci.utils import config +from ppsci.utils import logger + + +def get_vis_datas( + file_path: str, + date_strings: Tuple[str, ...], + num_timestamps: int, + vars_channel: Tuple[int, ...], + img_h: int, + data_mean: np.ndarray, + data_std: np.ndarray, +): + _file = h5py.File(file_path, "r")["fields"] + data = [] + for date_str in date_strings: + hours_since_jan_01_epoch = local_utils.date_to_hours(date_str) + ic = int(hours_since_jan_01_epoch / 6) + data.append(_file[ic : ic + num_timestamps + 1, vars_channel, 0:img_h]) + data = np.asarray(data) + + vis_datas = {"input": (data[:, 0] - data_mean) / data_std} + for t in range(num_timestamps): + hour = (t + 1) * 6 + data_t = data[:, t + 1] + wind_data = [] + for i in range(data_t.shape[0]): + wind_data.append((data_t[i][0] ** 2 + data_t[i][1] ** 2) ** 0.5) + vis_datas[f"target_{hour}h"] = np.asarray(wind_data) + return vis_datas + + +def copy_cur_file(output_dir): + os.makedirs(output_dir, exist_ok=True) + cur_file_path = os.path.abspath(__file__) + dst_file_path = os.path.join(output_dir, os.path.basename(__file__)) + shutil.copy(cur_file_path, dst_file_path) + + +if __name__ == "__main__": + args = config.parse_args() + # set random seed for reproducibility + ppsci.utils.set_random_seed(1024) + # Initialize distributed environment + dist.init_parallel_env() + + # set dataset path + TRAIN_FILE_PATH = "../train_data" + VALID_FILE_PATH = "../test_data" + DATA_MEAN_PATH = "../stat/mean_crop.npy" + DATA_STD_PATH = "../stat/std_crop.npy" + DATA_TIME_MEAN_PATH = "../stat/time_mean_crop.npy" + + MERGE_WEIGHTS_M = "../stat/mwp67.npy" + MERGE_WEIGHTS_N = "../stat/nwp67.npy" + + MERGE_LABLE = True + + # set training hyper-parameters + NUM_TIMESTAMPS = 48 + input_keys = ("input",) + output_keys = tuple(f"output_{i}" for i in range(NUM_TIMESTAMPS)) + IMG_H, IMG_W = 440, 408 + # FourCastNet HRRR Crop use 24 atmospheric variable,their index in the dataset is from 0 to 23. + # The variable name is 'z50', 'z500', 'z850', 'z1000', 't50', 't500', 't850', 'z1000', + # 's50', 's500', 's850', 's1000', 'u50', 'u500', 'u850', 'u1000', 'v50', 'v500', 'v850', 'v1000', + # 'mslp', 'u10', 'v10', 't2m'. + VARS_CHANNEL = list(range(24)) + VARIABLE_DICT = { + "z50": 0, + "z500": 1, + "z850": 2, + "z1000": 3, + "t50": 4, + "t500": 5, + "t850": 6, + "t1000": 7, + "s50": 8, + "s500": 9, + "s850": 10, + "s1000": 11, + "u50": 12, + "u500": 13, + "u850": 14, + "u1000": 15, + "v50": 16, + "v500": 17, + "v850": 18, + "v1000": 19, + "mslp": 20, + "u10": 21, + "v10": 22, + "t2m": 23, + } + # set output directory + OUTPUT_DIR = ( + "../output/hrrr_time_embedding_merge_train" + if not args.output_dir + else args.output_dir + ) + PRETRAINED_MODEL_PATH = ( + "../output/hrrr_time_embedding_merge_train/checkpoints/latest" + ) + # initialize logger + logger.init_logger("ppsci", f"{OUTPUT_DIR}/infer.log", "info") + copy_cur_file(OUTPUT_DIR) + + vdl_writer = vdl.LogWriter(f"{OUTPUT_DIR}/vdl_no_weight") + + data_mean, data_std = local_utils.get_mean_std( + DATA_MEAN_PATH, DATA_STD_PATH, VARS_CHANNEL + ) + data_time_mean = local_utils.get_time_mean( + DATA_TIME_MEAN_PATH, IMG_H, IMG_W, VARS_CHANNEL + ) + data_time_mean_normalize = np.expand_dims( + (data_time_mean - data_mean) / data_std, 0 + ) + + # set train transforms + transforms = [ + {"SqueezeData": {}}, + {"CropData": {"xmin": (0, 0), "xmax": (IMG_H, IMG_W)}}, + {"Normalize": {"mean": data_mean, "std": data_std}}, + ] + + # set eval dataloader config + eval_dataloader_cfg = { + "dataset": { + "name": "HRRRDataset", + "file_path": VALID_FILE_PATH, + "input_keys": input_keys, + "label_keys": output_keys, + "vars_channel": VARS_CHANNEL, + "transforms": transforms, + "num_label_timestamps": NUM_TIMESTAMPS, + "training": False, + "stride": 24, + "merge_label": MERGE_LABLE, + }, + "sampler": { + "name": "BatchSampler", + "drop_last": False, + "shuffle": False, + }, + "batch_size": 1, + } + + # set metirc + metric = { + "MAE": ppsci.metric.MAE(keep_batch=True), + "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE( + std=data_std, + keep_batch=True, + variable_dict=VARIABLE_DICT, + ), + "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC( + mean=data_time_mean_normalize, + keep_batch=True, + variable_dict=VARIABLE_DICT, + ), + } + + # set model + model = ppsci.arch.AFNOAttnParallelNet( + input_keys, + output_keys, + img_size=(IMG_H, IMG_W), + in_channels=len(VARS_CHANNEL), + out_channels=len(VARS_CHANNEL), + num_timestamps=NUM_TIMESTAMPS, + attn_channel_ratio=0.25, + merge_label=MERGE_LABLE, + merge_weights_m=MERGE_WEIGHTS_M, + merge_weights_n=MERGE_WEIGHTS_N, + ) + + # set validator for testing + sup_validator = ppsci.validate.SupervisedValidator( + eval_dataloader_cfg, + ppsci.loss.L2RelLoss(), + metric=metric, + name="Sup_Validator", + ) + validator = {sup_validator.name: sup_validator} + + # directly evaluate pretrained model + solver = ppsci.solver.Solver( + model, + output_dir=OUTPUT_DIR, + validator=validator, + pretrained_model_path=PRETRAINED_MODEL_PATH, + compute_metric_by_batch=True, + eval_with_no_grad=True, + vdl_writer=vdl_writer, + ) + solver.eval() diff --git a/jointContribution/yinglong/examples/train_finetune_parallel.py b/jointContribution/yinglong/examples/train_finetune_parallel.py new file mode 100644 index 0000000000..77bb785c15 --- /dev/null +++ b/jointContribution/yinglong/examples/train_finetune_parallel.py @@ -0,0 +1,239 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shutil +from typing import Tuple + +import h5py +import numpy as np +import paddle.distributed as dist +import utils as local_utils + +import ppsci +from ppsci.utils import config +from ppsci.utils import logger + + +def get_vis_datas( + file_path: str, + date_strings: Tuple[str, ...], + num_timestamps: int, + vars_channel: Tuple[int, ...], + img_h: int, + data_mean: np.ndarray, + data_std: np.ndarray, +): + _file = h5py.File(file_path, "r")["fields"] + data = [] + for date_str in date_strings: + hours_since_jan_01_epoch = local_utils.date_to_hours(date_str) + ic = int(hours_since_jan_01_epoch / 6) + data.append(_file[ic : ic + num_timestamps + 1, vars_channel, 0:img_h]) + data = np.asarray(data) + + vis_datas = {"input": (data[:, 0] - data_mean) / data_std} + for t in range(num_timestamps): + hour = (t + 1) * 6 + data_t = data[:, t + 1] + wind_data = [] + for i in range(data_t.shape[0]): + wind_data.append((data_t[i][0] ** 2 + data_t[i][1] ** 2) ** 0.5) + vis_datas[f"target_{hour}h"] = np.asarray(wind_data) + return vis_datas + + +def copy_cur_file(output_dir): + os.makedirs(output_dir, exist_ok=True) + cur_file_path = os.path.abspath(__file__) + dst_file_path = os.path.join(output_dir, os.path.basename(__file__)) + shutil.copy(cur_file_path, dst_file_path) + + +if __name__ == "__main__": + args = config.parse_args() + # set random seed for reproducibility + ppsci.utils.set_random_seed(1024) + # Initialize distributed environment + dist.init_parallel_env() + + # set dataset path + TRAIN_FILE_PATH = "../train_data" + VALID_FILE_PATH = "../test_data" + DATA_MEAN_PATH = "../stat/mean_crop.npy" + DATA_STD_PATH = "../stat/std_crop.npy" + DATA_TIME_MEAN_PATH = "../stat/time_mean_crop.npy" + MERGE_WEIGHTS_M = "../stat/mwp67.npy" + MERGE_WEIGHTS_N = "../stat/nwp67.npy" + + MERGE_LABLE = True + + # set training hyper-parameters + NUM_TIMESTAMPS = 2 if not args.num_timestamps else args.num_timestamps + input_keys = ("input",) + output_keys = tuple(f"output_{i}" for i in range(NUM_TIMESTAMPS)) + IMG_H, IMG_W = 440, 408 + EPOCHS = 15 if not args.epochs else args.epochs + # FourCastNet HRRR Crop use 24 atmospheric variable,their index in the dataset is from 0 to 23. + # The variable name is 'z50', 'z500', 'z850', 'z1000', 't50', 't500', 't850', 'z1000', + # 's50', 's500', 's850', 's1000', 'u50', 'u500', 'u850', 'u1000', 'v50', 'v500', 'v850', 'v1000', + # 'mslp', 'u10', 'v10', 't2m'. + VARS_CHANNEL = list(range(24)) + # set output directory + OUTPUT_DIR = ( + f"../output/hrrr_time_embedding_merge_train_finetune_{NUM_TIMESTAMPS}" + if not args.output_dir + else args.output_dir + ) + PRETRAINED_MODEL_PATH = ( + "../output/hrrr_time_embedding_merge_train/checkpoints/latest" + ) + + copy_cur_file(OUTPUT_DIR) + # initialize logger + logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") + + data_mean, data_std = local_utils.get_mean_std( + DATA_MEAN_PATH, DATA_STD_PATH, VARS_CHANNEL + ) + data_time_mean = local_utils.get_time_mean( + DATA_TIME_MEAN_PATH, IMG_H, IMG_W, VARS_CHANNEL + ) + data_time_mean_normalize = np.expand_dims( + (data_time_mean - data_mean) / data_std, 0 + ) + + # set train transforms + transforms = [ + {"SqueezeData": {}}, + {"CropData": {"xmin": (0, 0), "xmax": (IMG_H, IMG_W)}}, + {"Normalize": {"mean": data_mean, "std": data_std}}, + ] + # set train dataloader config + train_dataloader_cfg = { + "dataset": { + "name": "HRRRDataset", + "file_path": TRAIN_FILE_PATH, + "input_keys": input_keys, + "label_keys": output_keys, + "vars_channel": VARS_CHANNEL, + "num_label_timestamps": NUM_TIMESTAMPS, + "transforms": transforms, + "merge_label": MERGE_LABLE, + }, + "sampler": { + "name": "BatchSampler", + "drop_last": True, + "shuffle": True, + }, + "batch_size": 2, + "num_workers": 8, + } + # set constraint + sup_constraint = ppsci.constraint.SupervisedConstraint( + train_dataloader_cfg, + ppsci.loss.L2RelLoss(), + name="Sup", + ) + constraint = {sup_constraint.name: sup_constraint} + + # set iters_per_epoch by dataloader length + ITERS_PER_EPOCH = len(sup_constraint.data_loader) + + # set eval dataloader config + eval_dataloader_cfg = { + "dataset": { + "name": "HRRRDataset", + "file_path": VALID_FILE_PATH, + "input_keys": input_keys, + "label_keys": output_keys, + "vars_channel": VARS_CHANNEL, + "transforms": transforms, + "num_label_timestamps": NUM_TIMESTAMPS, + "training": False, + }, + "sampler": { + "name": "BatchSampler", + "drop_last": False, + "shuffle": False, + }, + "batch_size": 8, + } + + # # set metirc + metric = { + "MAE": ppsci.metric.MAE(keep_batch=True), + "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE( + std=data_std, + keep_batch=True, + variable_dict={"u10": 20, "v10": 21}, + ), + "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC( + mean=data_time_mean_normalize, + keep_batch=True, + variable_dict={"u10": 20, "v10": 21}, + ), + } + + # # set validator + sup_validator = ppsci.validate.SupervisedValidator( + eval_dataloader_cfg, + ppsci.loss.L2RelLoss(), + metric=metric, + name="Sup_Validator", + ) + validator = {sup_validator.name: sup_validator} + + # set model + model = ppsci.arch.AFNOAttnParallelNet( + input_keys, + output_keys, + img_size=(IMG_H, IMG_W), + in_channels=len(VARS_CHANNEL), + out_channels=len(VARS_CHANNEL), + attn_channel_ratio=0.25, + num_timestamps=NUM_TIMESTAMPS, + use_recompute=True, + merge_label=MERGE_LABLE, + merge_weights_m=MERGE_WEIGHTS_M, + merge_weights_n=MERGE_WEIGHTS_N, + ) + + # init optimizer and lr scheduler + lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine( + EPOCHS, + ITERS_PER_EPOCH, + 1e-5, + by_epoch=True, + )() + optimizer = ppsci.optimizer.Adam(lr_scheduler)((model,)) + + # initialize solver + solver = ppsci.solver.Solver( + model, + constraint, + OUTPUT_DIR, + optimizer, + lr_scheduler, + EPOCHS, + ITERS_PER_EPOCH, + eval_during_train=True, + validator=validator, + pretrained_model_path=PRETRAINED_MODEL_PATH, + compute_metric_by_batch=True, + eval_with_no_grad=True, + ) + # solver.model = ppsci.arch.convert_linear_layer_to_lora(solver.model, r=128) + solver.train() diff --git a/jointContribution/yinglong/examples/train_pretrain_parallel.py b/jointContribution/yinglong/examples/train_pretrain_parallel.py new file mode 100644 index 0000000000..3ba1b4c467 --- /dev/null +++ b/jointContribution/yinglong/examples/train_pretrain_parallel.py @@ -0,0 +1,230 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import warnings + +import examples.fourcastnet_hrrr.utils as local_utils +import numpy as np +import paddle.distributed as dist + +import ppsci +from ppsci.utils import config +from ppsci.utils import logger + +warnings.filterwarnings("ignore") + + +def copy_cur_file(output_dir): + os.makedirs(output_dir, exist_ok=True) + cur_file_path = os.path.abspath(__file__) + dst_file_path = os.path.join(output_dir, os.path.basename(__file__)) + shutil.copy(cur_file_path, dst_file_path) + + +if __name__ == "__main__": + args = config.parse_args() + # set random seed for reproducibility + ppsci.utils.set_random_seed(1024) + # Initialize distributed environment + dist.init_parallel_env() + + # set dataset path + TRAIN_FILE_PATH = "../train_data" + VALID_FILE_PATH = "../test_data" + DATA_MEAN_PATH = "../stat/mean_crop.npy" + DATA_STD_PATH = "../stat/std_crop.npy" + DATA_TIME_MEAN_PATH = "../stat/time_mean_crop.npy" + + USE_EXTRA = False + if USE_EXTRA: + EXTRA_FILE_PATH = "/root/ssd3/datasets/hrrr_h5_crop_rad/" + EXTRA_DATA_MEAN_PATH = "/root/ssd3/datasets/hrrr_h5_crop_rad/stat/mean_crop.npy" + EXTRA_DATA_STD_PATH = "/root/ssd3/datasets/hrrr_h5_crop_rad/stat/std_crop.npy" + EXTRA_DATA_TIME_MEAN_PATH = ( + "/root/ssd3/datasets/hrrr_h5_crop_rad/stat/time_mean_crop.npy" + ) + EXTRA_VARS_CHANNEL = list(range(2)) + else: + EXTRA_FILE_PATH = None + EXTRA_VARS_CHANNEL = None + + # set training hyper-parameters + input_keys = ("input",) + output_keys = ("output",) + IMG_H, IMG_W = 440, 408 # for HRRR dataset croped data + EPOCHS = 30 if not args.epochs else args.epochs + # FourCastNet HRRR Crop use 24 atmospheric variable,their index in the dataset is from 0 to 23. + # The variable name is 'z50', 'z500', 'z850', 'z1000', 't50', 't500', 't850', 'z1000', + # 's50', 's500', 's850', 's1000', 'u50', 'u500', 'u850', 'u1000', 'v50', 'v500', 'v850', 'v1000', + # 'mslp', 'u10', 'v10', 't2m'. + VARS_CHANNEL = list(range(24)) + # set output directory + OUTPUT_DIR = ( + "../output/hrrr_time_embedding_merge_train" + if not args.output_dir + else args.output_dir + ) + # initialize logger + logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") + + copy_cur_file(OUTPUT_DIR) + + data_mean, data_std = local_utils.get_mean_std( + DATA_MEAN_PATH, DATA_STD_PATH, VARS_CHANNEL + ) + data_time_mean = local_utils.get_time_mean( + DATA_TIME_MEAN_PATH, IMG_H, IMG_W, VARS_CHANNEL + ) + + if USE_EXTRA: + extra_data_mean, extra_data_std = local_utils.get_mean_std( + EXTRA_DATA_MEAN_PATH, EXTRA_DATA_STD_PATH, EXTRA_VARS_CHANNEL + ) + extra_data_time_mean = local_utils.get_time_mean( + EXTRA_DATA_TIME_MEAN_PATH, IMG_H, IMG_W, EXTRA_VARS_CHANNEL + ) + data_mean = np.concatenate((data_mean, extra_data_mean)) + data_std = np.concatenate((data_std, extra_data_std)) + data_time_mean = np.concatenate((data_time_mean, extra_data_time_mean)) + + data_time_mean_normalize = np.expand_dims( + (data_time_mean - data_mean) / data_std, 0 + ) + # set train transforms + transforms = [ + {"SqueezeData": {}}, + {"CropData": {"xmin": (0, 0), "xmax": (IMG_H, IMG_W)}}, + {"Normalize": {"mean": data_mean, "std": data_std}}, + ] + + # set train dataloader config + train_dataloader_cfg = { + "dataset": { + "name": "HRRRDataset", + "file_path": TRAIN_FILE_PATH, + "input_keys": input_keys, + "label_keys": output_keys, + "vars_channel": VARS_CHANNEL, + "transforms": transforms, + "extra_file_path": EXTRA_FILE_PATH, + "extra_vars_channel": EXTRA_VARS_CHANNEL, + }, + "sampler": { + "name": "BatchSampler", + "drop_last": True, + "shuffle": True, + }, + "batch_size": 2, + "num_workers": 8, + } + # set constraint + sup_constraint = ppsci.constraint.SupervisedConstraint( + train_dataloader_cfg, + ppsci.loss.L2RelLoss(), + name="Sup", + ) + constraint = {sup_constraint.name: sup_constraint} + + # set iters_per_epoch by dataloader length + ITERS_PER_EPOCH = len(sup_constraint.data_loader) + + # set eval dataloader config + eval_dataloader_cfg = { + "dataset": { + "name": "HRRRDataset", + "file_path": VALID_FILE_PATH, + "input_keys": input_keys, + "label_keys": output_keys, + "vars_channel": VARS_CHANNEL, + "transforms": transforms, + }, + "sampler": { + "name": "BatchSampler", + "drop_last": False, + "shuffle": False, + }, + "batch_size": 8, + } + + # set validator + sup_validator = ppsci.validate.SupervisedValidator( + eval_dataloader_cfg, + ppsci.loss.L2RelLoss(), + metric={ + "MAE": ppsci.metric.MAE(keep_batch=True), + "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE( + std=data_std, + keep_batch=True, + variable_dict={"u10": 21, "v10": 22}, + ), + "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC( + mean=data_time_mean_normalize, + keep_batch=True, + variable_dict={"u10": 21, "v10": 22}, + ), + }, + name="Sup_Validator", + ) + validator = {sup_validator.name: sup_validator} + + # set model + model = ppsci.arch.AFNOAttnParallelUNet( + input_keys, + output_keys, + img_size=(IMG_H, IMG_W), + in_channels=len(VARS_CHANNEL), + out_channels=len(VARS_CHANNEL), + attn_channel_ratio=[0.25] * 4 + [0.5] * 4 + [0.25] * 4, + ) + + model = ppsci.arch.AFNOAttnParallelNet( + input_keys, + output_keys, + img_size=(IMG_H, IMG_W), + in_channels=len(VARS_CHANNEL) + len(EXTRA_VARS_CHANNEL) + if USE_EXTRA + else len(VARS_CHANNEL), + out_channels=len(VARS_CHANNEL) + len(EXTRA_VARS_CHANNEL) + if USE_EXTRA + else len(VARS_CHANNEL), + attn_channel_ratio=0.25, + ) + + # init optimizer and lr scheduler + lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine( + EPOCHS, + ITERS_PER_EPOCH, + 5e-4, + by_epoch=True, + )() + optimizer = ppsci.optimizer.Adam(lr_scheduler)((model,)) + + # initialize solver + solver = ppsci.solver.Solver( + model, + constraint, + OUTPUT_DIR, + optimizer, + lr_scheduler, + EPOCHS, + ITERS_PER_EPOCH, + # eval_during_train=True, + # validator=validator, + compute_metric_by_batch=True, + eval_with_no_grad=True, + ) + # train model + solver.train() diff --git a/jointContribution/yinglong/examples/utils.py b/jointContribution/yinglong/examples/utils.py new file mode 100644 index 0000000000..500432d267 --- /dev/null +++ b/jointContribution/yinglong/examples/utils.py @@ -0,0 +1,35 @@ +from datetime import datetime +from typing import Optional +from typing import Tuple + +import numpy as np + + +def date_to_hours(date: str): + date_obj = datetime.strptime(date, "%Y-%m-%d %H:%M:%S") + day_of_year = date_obj.timetuple().tm_yday - 1 + hour_of_day = date_obj.timetuple().tm_hour + hours_since_jan_01_epoch = 24 * day_of_year + hour_of_day + return hours_since_jan_01_epoch + + +def get_mean_std(mean_path: str, std_path: str, vars_channel: Tuple[int, ...]): + data_mean = np.load(mean_path).reshape(-1, 1, 1).astype(np.float32) + data_mean = data_mean[vars_channel] + data_std = np.load(std_path).reshape(-1, 1, 1).astype(np.float32) + data_std = data_std[vars_channel] + return data_mean, data_std + + +def get_time_mean( + time_mean_path: str, + img_h: int, + img_w: int, + vars_channel: Optional[Tuple[int, ...]] = None, +): + time_mean = np.load(time_mean_path).astype(np.float32) + if vars_channel is not None: + time_mean = time_mean[vars_channel, :img_h, :img_w] + else: + time_mean = time_mean[:, :img_h, :img_w] + return time_mean diff --git a/jointContribution/yinglong/ppsci/__init__.py b/jointContribution/yinglong/ppsci/__init__.py new file mode 100644 index 0000000000..0a6929d99e --- /dev/null +++ b/jointContribution/yinglong/ppsci/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ppsci import arch # isort:skip + +# from ppsci import autodiff # isort:skip +from ppsci import constraint # isort:skip +from ppsci import data # isort:skip + +# from ppsci import equation # isort:skip +# from ppsci import geometry # isort:skip +from ppsci import loss # isort:skip +from ppsci import metric # isort:skip +from ppsci import optimizer # isort:skip +from ppsci import utils # isort:skip + +# from ppsci import visualize # isort:skip +from ppsci import validate # isort:skip +from ppsci import solver # isort:skip + +__all__ = [ + "arch", + # "autodiff", + "constraint", + "data", + # "equation", + # "geometry", + "loss", + "metric", + "optimizer", + "utils", + # "visualize", + "validate", + "solver", +] diff --git a/jointContribution/yinglong/ppsci/arch/__init__.py b/jointContribution/yinglong/ppsci/arch/__init__.py new file mode 100644 index 0000000000..520773a7df --- /dev/null +++ b/jointContribution/yinglong/ppsci/arch/__init__.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +# from ppsci.arch.lora import convert_linear_layer_to_lora + +from ppsci.arch.mlp import MLP # isort:skip +# from ppsci.arch.embedding_koopman import LorenzEmbedding # isort:skip +# from ppsci.arch.embedding_koopman import RosslerEmbedding # isort:skip +# from ppsci.arch.embedding_koopman import CylinderEmbedding # isort:skip +# from ppsci.arch.physx_transformer import PhysformerGPT2 # isort:skip +# from ppsci.arch.model_list import ModelList # isort:skip +from ppsci.arch.afno import AFNONet # isort:skip +from ppsci.arch.afno import AFNOUNet # isort:skip +from ppsci.arch.afno import AFNOAttnNet # isort:skip +from ppsci.arch.afno import AFNOUNetWithAttn # isort:skip +from ppsci.arch.afno import AFNONetMultiInput # isort:skip +from ppsci.arch.afno import AFNOUNetMultiInput # isort:skip +from ppsci.arch.afno_attn_parallel import AFNOAttnParallelNet # isort:skip +from ppsci.arch.afno_attn_parallel import AFNOAttnParallelNetV2 # isort:skip +from ppsci.arch.afno_attn_parallel import AFNOAttnParallelNetV3 # isort:skip +from ppsci.arch.afno_attn_parallel import AFNOAttnParallelUNet # isort:skip +from ppsci.arch.afno import PrecipNet # isort:skip +from ppsci.utils import logger # isort:skip + + +__all__ = [ + "MLP", + "LorenzEmbedding", + "RosslerEmbedding", + "CylinderEmbedding", + "PhysformerGPT2", + "ModelList", + "AFNONet", + "AFNOUNet", + "AFNOAttnNet", + "AFNOUNetWithAttn", + "AFNONetMultiInput", + "AFNOUNetMultiInput", + "AFNOAttnParallelNet", + "AFNOAttnParallelNetV2", + "AFNOAttnParallelNetV3", + "AFNOAttnParallelUNet", + "PrecipNet", + "build_model", + "convert_linear_layer_to_lora", +] + + +def build_model(cfg): + """Build model + + Args: + cfg (AttrDict): Arch config. + + Returns: + nn.Layer: Model. + """ + cfg = copy.deepcopy(cfg) + arch_cls = cfg.pop("name") + arch = eval(arch_cls)(**cfg) + + logger.debug(str(arch)) + + return arch diff --git a/jointContribution/yinglong/ppsci/arch/activation.py b/jointContribution/yinglong/ppsci/arch/activation.py new file mode 100644 index 0000000000..9e59321afc --- /dev/null +++ b/jointContribution/yinglong/ppsci/arch/activation.py @@ -0,0 +1,48 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import paddle +import paddle.nn.functional as F +from paddle import nn + +act_func_dict = { + "elu": F.elu, + "relu": F.relu, + "selu": F.selu, + "gelu": F.gelu, + "sigmoid": F.sigmoid, + "silu": F.silu, + "sin": paddle.sin, + "cos": paddle.cos, + "swish": F.silu, + "tanh": F.tanh, + "identity": nn.Identity(), +} + + +def get_activation(act_name: str) -> Callable: + """Get activation function according to act_name. + + Args: + act_name (str): Name of activation, such as "tanh". + + Returns: + Callable: Paddle activation function. + """ + if act_name.lower() not in act_func_dict: + raise ValueError(f"act_name({act_name}) not found in act_func_dict") + + return act_func_dict[act_name.lower()] diff --git a/jointContribution/yinglong/ppsci/arch/afno.py b/jointContribution/yinglong/ppsci/arch/afno.py new file mode 100644 index 0000000000..f7039325e8 --- /dev/null +++ b/jointContribution/yinglong/ppsci/arch/afno.py @@ -0,0 +1,1858 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Code below is heavily based on [FourCastNet](https://github.com/NVlabs/FourCastNet) +""" + +from collections.abc import Callable +from functools import partial +from typing import Optional +from typing import Tuple + +import paddle +import paddle.fft +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed.fleet.utils import recompute + +from ppsci.arch import activation as act_mod +from ppsci.arch import base +from ppsci.utils import initializer + + +def drop_path( + x: paddle.Tensor, + drop_prob: float = 0.0, + training: bool = False, + scale_by_keep: bool = True, +) -> paddle.Tensor: + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + + Args: + x (paddle.Tensor): The tensor to apply. + drop_prob (float, optional): Drop paths probability. Defaults to 0.0. + training (bool, optional): Whether at training mode. Defaults to False. + scale_by_keep (bool, optional): Whether upscale the output. Defaults to True. + + Returns: + paddle.Tensor: Output tensor after apply dropout. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = paddle.full(shape, keep_prob, x.dtype) + random_tensor = paddle.bernoulli(random_tensor) + if keep_prob > 0.0 and scale_by_keep: + random_tensor = random_tensor / keep_prob + return x * random_tensor + + +class DropPath(nn.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Args: + drop_prob (float, optional): Drop paths probability. Defaults to 0.0. + scale_by_keep (bool, optional): Whether upscale the output. Defaults to True. + """ + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super().__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class PeriodicPad2d(nn.Layer): + """Pad longitudinal (left-right) circular and pad latitude (top-bottom) with zeros. + + Args: + pad (int): Number of pad. + """ + + def __init__(self, pad: int): + super(PeriodicPad2d, self).__init__() + self.pad = pad + + def forward(self, x): + # pad left and right circular + out = F.pad(x, (self.pad, self.pad, 0, 0), mode="circular") + # pad top and bottom zeros + out = F.pad( + out, + (0, 0, 0, 0, self.pad, self.pad, 0, 0), + mode="constant", + value=0, + ) + return out + + +class MLP(nn.Layer): + """Multi layer perceptron module used in Transformer. + + Args: + in_features (int): Number of the input features. + hidden_features (Optional[int]): Number of the hidden size. Defaults to None. + out_features (Optional[int]): Number of the output features. Defaults to None. + activation (str, optional): Name of activation function. Defaults to "gelu". + drop (float, optional): Probability of dropout the units. Defaults to 0.0. + """ + + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + activation: str = "gelu", + drop: float = 0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_mod.get_activation(activation) + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class AFNO2D(nn.Layer): + """2D Adaptive Fourier Neural Operators. + + Args: + hidden_size (int): Number of hidden size. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + hidden_size_factor (int, optional): The factor of hidden size. Defaults to 1. + scale (float, optional): The scale factor of the parameter when initialization. Defaults to 0.02. + """ + + def __init__( + self, + hidden_size: int, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + hidden_size_factor: int = 1, + scale: float = 0.02, + ): + super().__init__() + if hidden_size % num_blocks != 0: + raise ValueError( + f"hidden_size({hidden_size}) should be divisble by num_blocks({num_blocks})." + ) + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = scale + + self.w1 = self.create_parameter( + shape=( + 2, + self.num_blocks, + self.block_size, + self.block_size * self.hidden_size_factor, + ), + default_initializer=nn.initializer.Normal(std=self.scale), + ) + self.b1 = self.create_parameter( + shape=(2, self.num_blocks, self.block_size * self.hidden_size_factor), + default_initializer=nn.initializer.Normal(std=self.scale), + ) + self.w2 = self.create_parameter( + shape=( + 2, + self.num_blocks, + self.block_size * self.hidden_size_factor, + self.block_size, + ), + default_initializer=nn.initializer.Normal(std=self.scale), + ) + self.b2 = self.create_parameter( + shape=(2, self.num_blocks, self.block_size), + default_initializer=nn.initializer.Normal(std=self.scale), + ) + + def forward(self, x): + bias = x + + B, H, W, C = x.shape + + x = paddle.fft.rfft2(x, axes=(1, 2), norm="ortho") + x = x.reshape((B, H, W // 2 + 1, self.num_blocks, self.block_size)) + + o1_shape = ( + B, + H, + W // 2 + 1, + self.num_blocks, + self.block_size * self.hidden_size_factor, + ) + o1_real = paddle.zeros(o1_shape) + o1_imag = paddle.zeros(o1_shape) + o2_real = paddle.zeros(x.shape) + o2_imag = paddle.zeros(x.shape) + + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + st, end = total_modes - kept_modes, total_modes + kept_modes + + o1_real[:, st:end, :kept_modes] = F.relu( + paddle.einsum( + "xyzbi,bio->xyzbo", + x[:, st:end, :kept_modes].real(), + self.w1[0], + ) + - paddle.einsum( + "xyzbi,bio->xyzbo", + x[:, st:end, :kept_modes].imag(), + self.w1[1], + ) + + self.b1[0] + ) + + o1_imag[:, st:end, :kept_modes] = F.relu( + paddle.einsum( + "xyzbi,bio->xyzbo", + x[:, st:end, :kept_modes].imag(), + self.w1[0], + ) + + paddle.einsum( + "xyzbi,bio->xyzbo", + x[:, st:end, :kept_modes].real(), + self.w1[1], + ) + + self.b1[1] + ) + + o2_real[:, st:end, :kept_modes] = ( + paddle.einsum( + "xyzbi,bio->xyzbo", + o1_real[:, st:end, :kept_modes], + self.w2[0], + ) + - paddle.einsum( + "xyzbi,bio->xyzbo", + o1_imag[:, st:end, :kept_modes], + self.w2[1], + ) + + self.b2[0] + ) + + o2_imag[:, st:end, :kept_modes] = ( + paddle.einsum( + "xyzbi,bio->xyzbo", + o1_imag[:, st:end, :kept_modes], + self.w2[0], + ) + + paddle.einsum( + "xyzbi,bio->xyzbo", + o1_real[:, st:end, :kept_modes], + self.w2[1], + ) + + self.b2[1] + ) + + x = paddle.stack([o2_real, o2_imag], axis=-1) + x = F.softshrink(x, threshold=self.sparsity_threshold) + x = paddle.as_complex(x) + x = x.reshape((B, H, W // 2 + 1, C)) + x = paddle.fft.irfft2(x, s=(H, W), axes=(1, 2), norm="ortho") + + return x + bias + + +class Block(nn.Layer): + """AFNO network block. + + Args: + dim (int): The input tensor dimension. + mlp_ratio (float, optional): The ratio used in MLP. Defaults to 4.0. + drop (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + activation (str, optional): Name of activation function. Defaults to "gelu". + norm_layer (nn.Layer, optional): Class of norm layer. Defaults to nn.LayerNorm. + double_skip (bool, optional): Whether use double skip. Defaults to True. + num_blocks (int, optional): The number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + """ + + def __init__( + self, + dim: int, + mlp_ratio: float = 4.0, + drop: float = 0.0, + drop_path: float = 0.0, + activation: str = "gelu", + norm_layer: nn.Layer = nn.LayerNorm, + double_skip: bool = True, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.filter = AFNO2D( + dim, num_blocks, sparsity_threshold, hard_thresholding_fraction + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + activation=activation, + drop=drop, + ) + self.double_skip = double_skip + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.filter(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = x + residual + return x + + +class Attention(nn.Layer): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + # B= paddle.shape(x)[0] + N, C = x.shape[1:] + qkv = ( + self.qkv(x) + .reshape((-1, N, 3, self.num_heads, C // self.num_heads)) + .transpose((2, 0, 3, 1, 4)) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale + attn = nn.functional.softmax(attn, axis=-1) + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block_attention(nn.Layer): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + activation="gelu", + norm_layer="nn.LayerNorm", + epsilon=1e-5, + ): + super().__init__() + if isinstance(norm_layer, str): + self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) + elif isinstance(norm_layer, Callable): + self.norm1 = norm_layer(dim) + else: + raise TypeError("The norm_layer must be str or paddle.nn.layer.Layer class") + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if isinstance(norm_layer, str): + self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) + elif isinstance(norm_layer, Callable): + self.norm2 = norm_layer(dim) + else: + raise TypeError("The norm_layer must be str or paddle.nn.layer.Layer class") + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + activation=activation, + drop=drop, + ) + + def forward(self, x): + B, H, W, C = x.shape + x = x.reshape([B, H * W, C]) + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x.reshape([B, H, W, C]) + return x + + +class PatchEmbed(nn.Layer): + """Patch embedding module. + + Args: + img_size (Tuple[int, ...], optional): Image size. Defaults to (224, 224). + patch_size (Tuple[int, ...], optional): Patch size. Defaults to (16, 16). + in_channels (int, optional): The input tensor channels. Defaults to 3. + embed_dim (int, optional): The output tensor channels. Defaults to 768. + """ + + def __init__( + self, + img_size: Tuple[int, ...] = (224, 224), + patch_size: Tuple[int, ...] = (16, 16), + in_channels: int = 3, + embed_dim: int = 768, + ): + super().__init__() + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2D( + in_channels, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x): + _, _, H, W = x.shape + if not (H == self.img_size[0] and W == self.img_size[1]): + raise ValueError( + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + ) + x = self.proj(x).flatten(2).transpose((0, 2, 1)) + return x + + +class AFNONet(base.Arch): + """Adaptive Fourier Neural Network. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 20. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 20, + embed_dim: int = 768, + depth: int = 12, + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps: int = 1, + use_recompute=False, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + self.use_recompute = use_recompute + norm_layer = partial(nn.LayerNorm, epsilon=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + data = paddle.zeros((1, num_patches, embed_dim)) + data = initializer.trunc_normal_(data, std=0.02) + self.pos_embed = paddle.create_parameter( + shape=data.shape, + dtype=data.dtype, + default_initializer=nn.initializer.Assign(data), + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + self.blocks = nn.LayerList( + [ + Block( + dim=embed_dim, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + for i in range(depth) + ] + ) + + self.norm = norm_layer(embed_dim) + self.head = nn.Linear( + embed_dim, + self.out_channels * self.patch_size[0] * self.patch_size[1], + bias_attr=False, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape((B, self.h, self.w, self.embed_dim)) + for block in self.blocks: + # x = block(x) + if not self.use_recompute: + x = block(x) + else: + x = recompute(block, x) + x = self.head(x) + + b = x.shape[0] + p1 = self.patch_size[0] + p2 = self.patch_size[1] + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + c_out = x.shape[3] // (p1 * p2) + x = x.reshape((b, h, w, p1, p2, c_out)) + x = x.transpose((0, 5, 1, 3, 2, 4)) + x = x.reshape((b, c_out, h * p1, w * p2)) + + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys) + + y = [] + input = x + for _ in range(self.num_timestamps): + out = self.forward_tensor(input) + y.append(out) + input = out + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y + + +class PrecipNet(base.Arch): + """Precipitation Network. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + wind_model (base.Arch): Wind model. + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 1. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> wind_model = ppsci.arch.AFNONet(("input", ), ("output", )) + >>> model = ppsci.arch.PrecipNet(("input", ), ("output", ), wind_model) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + wind_model: base.Arch, + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 1, + embed_dim: int = 768, + depth: int = 12, + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps=1, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + self.backbone = AFNONet( + ("input",), + ("output",), + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + depth=depth, + mlp_ratio=mlp_ratio, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + num_blocks=num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + self.ppad = PeriodicPad2d(1) + self.conv = nn.Conv2D( + self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=0 + ) + self.act = nn.ReLU() + self.apply(self._init_weights) + self.wind_model = wind_model + self.wind_model.eval() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x): + x = self.backbone.forward_tensor(x) + x = self.ppad(x) + x = self.conv(x) + x = self.act(x) + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys) + + input_wind = x + y = [] + for _ in range(self.num_timestamps): + with paddle.no_grad(): + out_wind = self.wind_model.forward_tensor(input_wind) + out = self.forward_tensor(out_wind) + y.append(out) + input_wind = out_wind + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y + + +class PatchMerging(nn.Layer): + r"""Patch Merging Layer. + + Args: + dim (int): Number of input channels. + norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + B, H, W, C = x.shape + if H % 2 != 0 or W % 2 != 0: + pad = [0, 0, 0, H % 2, 0, W % 2, 0, 0] + x = F.pad(x, pad, data_format="NHWC") + B, H, W, C = x.shape + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + x = x.reshape([B, H // 2, W // 2, 2 * C]) + + return x + + +class UpSample(nn.Layer): + r"""upsample layer.""" + + def __init__(self, h, w, input_dim, output_dim, norm_layer=nn.LayerNorm): + super().__init__() + self.h, self.w = h, w + self.input_dim = input_dim + self.output_dim = output_dim + + self.linear1 = nn.Linear(input_dim, output_dim * 4, bias_attr=False) + self.linear2 = nn.Linear(output_dim, output_dim, bias_attr=False) + self.norm = norm_layer(output_dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + # H_, W = self.input_resolution + B, H, W, C = x.shape + x = self.linear1(x) + x = x.reshape([B, H, W, 2, 2, C // 2]) + x = x.transpose([0, 1, 3, 2, 4, 5]) + x = x.reshape([B, H * 2, W * 2, C // 2]) + + x = x[:, : self.h, : self.w, :] + + x = self.norm(x) + x = self.linear2(x) + + return x + + +class AFNOUNet(base.Arch): + """Adaptive Fourier Unet. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 20. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 20, + embed_dim: int = 768, + depths=[2, 4, 4, 2], + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps: int = 1, + linear_head=True, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + self.linear_head = linear_head + norm_layer = partial(nn.LayerNorm, epsilon=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + data = paddle.zeros((1, num_patches, embed_dim)) + data = initializer.trunc_normal_(data, std=0.02) + self.pos_embed = paddle.create_parameter( + shape=data.shape, + dtype=data.dtype, + default_initializer=nn.initializer.Assign(data), + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + depth = sum(depths) + dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] + self.num_layers = len(depths) + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + self.layers = nn.LayerList() + for i_layer in range(self.num_layers): + layer_i = nn.Sequential() + for j in range(depths[i_layer]): + if i_layer >= self.num_layers // 2: + dim = embed_dim * 2 ** (self.num_layers - 1 - i_layer) + else: + dim = embed_dim * 2**i_layer + layer = Block( + dim=dim, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[j] + if i_layer == 0 + else dpr[sum(depths[:i_layer]) + j], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + layer_i.add_sublayer("block{}".format(j), layer) + self.layers.append(layer_i) + + self.patch_merger = PatchMerging(embed_dim, norm_layer) + self.upsample = UpSample(self.h, self.w, embed_dim * 2, embed_dim, norm_layer) + + self.norm = norm_layer(embed_dim) + if linear_head: + self.head = nn.Linear( + embed_dim * 2, + self.out_channels * self.patch_size[0] * self.patch_size[1], + bias_attr=False, + ) + else: + self.head = nn.Sequential( + ( + "conv1", + nn.Conv2DTranspose( + embed_dim * 2, + self.out_channels * 16, + kernel_size=(2, 2), + stride=(2, 2), + ), + ), + ("act1", nn.Tanh()), + ( + "conv2", + nn.Conv2DTranspose( + self.out_channels * 16, + self.out_channels * 4, + kernel_size=(2, 2), + stride=(2, 2), + ), + ), + ("act2", nn.Tanh()), + ( + "conv3", + nn.Conv2DTranspose( + self.out_channels * 4, + self.out_channels, + kernel_size=(2, 2), + stride=(2, 2), + ), + ), + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape((B, self.h, self.w, self.embed_dim)) + layer_0_out = None + for i, layer in enumerate(self.layers): + if i == 0: + x = layer(x) + layer_0_out = x + x = self.patch_merger(x) + elif i == 1: + x = layer(x) + elif i == 2: + x = layer(x) + elif i == 3: + x = self.upsample(x) + x = layer(x) + x = paddle.concat([x, layer_0_out], axis=3) + + if self.linear_head: + x = self.head(x) + + b = x.shape[0] + p1 = self.patch_size[0] + p2 = self.patch_size[1] + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + c_out = x.shape[3] // (p1 * p2) + x = x.reshape((b, h, w, p1, p2, c_out)) + x = x.transpose((0, 5, 1, 3, 2, 4)) + x = x.reshape((b, c_out, h * p1, w * p2)) + else: + B, H, W, C = x.shape + x = x.transpose([0, 3, 1, 2]) + x = self.head(x) + + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys) + + y = [] + input = x + for _ in range(self.num_timestamps): + out = self.forward_tensor(input) + y.append(out) + input = out + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y + + +class AFNOAttnNet(base.Arch): + """Adaptive Fourier Neural Network. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 20. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 20, + embed_dim: int = 768, + depths: Tuple[int, ...] = [4, 8], + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps: int = 1, + num_heads: int = 12, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + norm_layer = partial(nn.LayerNorm, epsilon=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + data = paddle.zeros((1, num_patches, embed_dim)) + data = initializer.trunc_normal_(data, std=0.02) + self.pos_embed = paddle.create_parameter( + shape=data.shape, + dtype=data.dtype, + default_initializer=nn.initializer.Assign(data), + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + depth = sum(depths) + dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + self.blocks = nn.LayerList( + [ + Block( + dim=embed_dim, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + for i in range(depths[0]) + ] + + [ + Block_attention( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[depths[0] + i], + norm_layer=norm_layer, + ) + for i in range(depths[1]) + ] + ) + + self.norm = norm_layer(embed_dim) + self.head = nn.Linear( + embed_dim, + self.out_channels * self.patch_size[0] * self.patch_size[1], + bias_attr=False, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape((B, self.h, self.w, self.embed_dim)) + for block in self.blocks: + x = block(x) + + x = self.head(x) + + b = x.shape[0] + p1 = self.patch_size[0] + p2 = self.patch_size[1] + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + c_out = x.shape[3] // (p1 * p2) + x = x.reshape((b, h, w, p1, p2, c_out)) + x = x.transpose((0, 5, 1, 3, 2, 4)) + x = x.reshape((b, c_out, h * p1, w * p2)) + + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys) + + y = [] + input = x + for _ in range(self.num_timestamps): + out = self.forward_tensor(input) + y.append(out) + input = out + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y + + +class AFNOUNetWithAttn(base.Arch): + """Adaptive Fourier Unet. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 20. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 20, + embed_dim: int = 768, + depths=[4, 2, 2, 4], + attn_layer=[False, True, True, True], + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps: int = 1, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + self.attn_layer = attn_layer + norm_layer = partial(nn.LayerNorm, epsilon=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + data = paddle.zeros((1, num_patches, embed_dim)) + data = initializer.trunc_normal_(data, std=0.02) + self.pos_embed = paddle.create_parameter( + shape=data.shape, + dtype=data.dtype, + default_initializer=nn.initializer.Assign(data), + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + depth = sum(depths) + dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] + self.num_layers = len(depths) + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + self.layers = nn.LayerList() + for i_layer in range(self.num_layers): + layer_i = nn.Sequential() + for j in range(depths[i_layer]): + if i_layer >= self.num_layers // 2: + dim = embed_dim * 2 ** (self.num_layers - 1 - i_layer) + else: + dim = embed_dim * 2**i_layer + if self.attn_layer[i_layer] is False: + layer = Block( + dim=dim, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[j] + if i_layer == 0 + else dpr[sum(depths[:i_layer]) + j], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + else: + layer = Block_attention( + dim=dim, + num_heads=12, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[j] + if i_layer == 0 + else dpr[sum(depths[:i_layer]) + j], + norm_layer=norm_layer, + ) + layer_i.add_sublayer("block{}".format(j), layer) + self.layers.append(layer_i) + + self.patch_merger = PatchMerging(embed_dim, norm_layer) + self.upsample = UpSample(self.h, self.w, embed_dim * 2, embed_dim, norm_layer) + + self.norm = norm_layer(embed_dim) + self.head = nn.Linear( + embed_dim * 2, + self.out_channels * self.patch_size[0] * self.patch_size[1], + bias_attr=False, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape((B, self.h, self.w, self.embed_dim)) + layer_0_out = None + for i, layer in enumerate(self.layers): + if i == 0: + x = layer(x) + layer_0_out = x + x = self.patch_merger(x) + elif i == 1: + x = layer(x) + elif i == 2: + x = layer(x) + elif i == 3: + x = self.upsample(x) + x = layer(x) + x = paddle.concat([x, layer_0_out], axis=3) + + x = self.head(x) + + b = x.shape[0] + p1 = self.patch_size[0] + p2 = self.patch_size[1] + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + c_out = x.shape[3] // (p1 * p2) + x = x.reshape((b, h, w, p1, p2, c_out)) + x = x.transpose((0, 5, 1, 3, 2, 4)) + x = x.reshape((b, c_out, h * p1, w * p2)) + + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys) + + y = [] + input = x + for _ in range(self.num_timestamps): + out = self.forward_tensor(input) + y.append(out) + input = out + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y + + +class AFNONetMultiInput(base.Arch): + """Adaptive Fourier Neural Network. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 20. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 20, + embed_dim: int = 768, + depth: int = 12, + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps: int = 1, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + norm_layer = partial(nn.LayerNorm, epsilon=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=self.patch_size, + in_channels=self.in_channels * len(input_keys), + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + data = paddle.zeros((1, num_patches, embed_dim)) + data = initializer.trunc_normal_(data, std=0.02) + self.pos_embed = paddle.create_parameter( + shape=data.shape, + dtype=data.dtype, + default_initializer=nn.initializer.Assign(data), + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + self.blocks = nn.LayerList( + [ + Block( + dim=embed_dim, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + for i in range(depth) + ] + ) + + self.norm = norm_layer(embed_dim) + self.head = nn.Linear( + embed_dim, + self.out_channels * self.patch_size[0] * self.patch_size[1], + bias_attr=False, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape((B, self.h, self.w, self.embed_dim)) + for block in self.blocks: + x = block(x) + + x = self.head(x) + + b = x.shape[0] + p1 = self.patch_size[0] + p2 = self.patch_size[1] + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + c_out = x.shape[3] // (p1 * p2) + x = x.reshape((b, h, w, p1, p2, c_out)) + x = x.transpose((0, 5, 1, 3, 2, 4)) + x = x.reshape((b, c_out, h * p1, w * p2)) + + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys, axis=1) + + y = [] + input = x + for _ in range(self.num_timestamps): + out = self.forward_tensor(input) + y.append(out) + input = paddle.concat( + [ + out, + input[ + :, + : -self.out_channels, + ], + ], + axis=1, + ) + + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y + + +class AFNOUNetMultiInput(base.Arch): + """Adaptive Fourier Unet. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 20. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 20, + embed_dim: int = 768, + depths=[2, 4, 4, 2], + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps: int = 1, + linear_head=True, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + self.linear_head = linear_head + norm_layer = partial(nn.LayerNorm, epsilon=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=self.patch_size, + in_channels=self.in_channels * len(self.input_keys), + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + data = paddle.zeros((1, num_patches, embed_dim)) + data = initializer.trunc_normal_(data, std=0.02) + self.pos_embed = paddle.create_parameter( + shape=data.shape, + dtype=data.dtype, + default_initializer=nn.initializer.Assign(data), + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + depth = sum(depths) + dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] + self.num_layers = len(depths) + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + self.layers = nn.LayerList() + for i_layer in range(self.num_layers): + layer_i = nn.Sequential() + for j in range(depths[i_layer]): + if i_layer >= self.num_layers // 2: + dim = embed_dim * 2 ** (self.num_layers - 1 - i_layer) + else: + dim = embed_dim * 2**i_layer + layer = Block( + dim=dim, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[j] + if i_layer == 0 + else dpr[sum(depths[:i_layer]) + j], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + layer_i.add_sublayer("block{}".format(j), layer) + self.layers.append(layer_i) + + self.patch_merger = PatchMerging(embed_dim, norm_layer) + self.upsample = UpSample(self.h, self.w, embed_dim * 2, embed_dim, norm_layer) + + self.norm = norm_layer(embed_dim) + if linear_head: + self.head = nn.Linear( + embed_dim * 2, + self.out_channels * self.patch_size[0] * self.patch_size[1], + bias_attr=False, + ) + else: + self.head = nn.Sequential( + ( + "conv1", + nn.Conv2DTranspose( + embed_dim * 2, + self.out_channels * 16, + kernel_size=(2, 2), + stride=(2, 2), + ), + ), + ("act1", nn.Tanh()), + ( + "conv2", + nn.Conv2DTranspose( + self.out_channels * 16, + self.out_channels * 4, + kernel_size=(2, 2), + stride=(2, 2), + ), + ), + ("act2", nn.Tanh()), + ( + "conv3", + nn.Conv2DTranspose( + self.out_channels * 4, + self.out_channels, + kernel_size=(2, 2), + stride=(2, 2), + ), + ), + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape((B, self.h, self.w, self.embed_dim)) + layer_0_out = None + for i, layer in enumerate(self.layers): + if i == 0: + x = layer(x) + layer_0_out = x + x = self.patch_merger(x) + elif i == 1: + x = layer(x) + elif i == 2: + x = layer(x) + elif i == 3: + x = self.upsample(x) + x = layer(x) + x = paddle.concat([x, layer_0_out], axis=3) + + if self.linear_head: + x = self.head(x) + + b = x.shape[0] + p1 = self.patch_size[0] + p2 = self.patch_size[1] + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + c_out = x.shape[3] // (p1 * p2) + x = x.reshape((b, h, w, p1, p2, c_out)) + x = x.transpose((0, 5, 1, 3, 2, 4)) + x = x.reshape((b, c_out, h * p1, w * p2)) + else: + B, H, W, C = x.shape + x = x.transpose([0, 3, 1, 2]) + x = self.head(x) + + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys, axis=1) + + y = [] + input = x + for _ in range(self.num_timestamps): + out = self.forward_tensor(input) + y.append(out) + input = paddle.concat( + [ + out, + input[ + :, + : -self.out_channels, + ], + ], + axis=1, + ) + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y diff --git a/jointContribution/yinglong/ppsci/arch/afno_attn_parallel.py b/jointContribution/yinglong/ppsci/arch/afno_attn_parallel.py new file mode 100644 index 0000000000..af80e5fa11 --- /dev/null +++ b/jointContribution/yinglong/ppsci/arch/afno_attn_parallel.py @@ -0,0 +1,1638 @@ +from functools import partial +from typing import Tuple + +import numpy as np +import paddle +import paddle.fft +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed.fleet.utils import recompute +from paddle.nn.initializer import Constant +from paddle.nn.initializer import Normal +from paddle.nn.initializer import TruncatedNormal + +from ppsci.arch import base +from ppsci.utils import initializer + +from .afno import AFNO2D +from .afno import MLP +from .afno import DropPath +from .afno import PatchEmbed +from .afno import PatchMerging +from .afno import UpSample +from .time_embedding import TimeFeatureEmbedding + +trunc_normal_ = TruncatedNormal(std=0.02) +normal_ = Normal +zeros_ = Constant(value=0.0) +ones_ = Constant(value=1.0) + + +def to_2tuple(x): + return tuple([x] * 2) + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.reshape([B, H // window_size, window_size, W // window_size, window_size, C]) + windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C]) + return windows + + +def window_reverse(windows, window_size, H, W, C): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + x = windows.reshape( + [-1, H // window_size, W // window_size, window_size, window_size, C] + ) + x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, H, W, C]) + return x + + +class WindowAttention(nn.Layer): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + # 2*Wh-1 * 2*Ww-1, nH + self.relative_position_bias_table = self.create_parameter( + shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads), + default_initializer=zeros_, + ) + self.add_parameter( + "relative_position_bias_table", self.relative_position_bias_table + ) + + # get pair-wise relative position index for each token inside the window + coords_h = paddle.arange(self.window_size[0]) + coords_w = paddle.arange(self.window_size[1]) + coords = paddle.stack(paddle.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww + + coords_flatten_1 = coords_flatten.unsqueeze(axis=2) + coords_flatten_2 = coords_flatten.unsqueeze(axis=1) + relative_coords = coords_flatten_1 - coords_flatten_2 + + relative_coords = relative_coords.transpose([1, 2, 0]) # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table) + self.softmax = nn.Softmax(axis=-1) + + def eval( + self, + ): + # this is used to re-param swin for model export + relative_position_bias_table = self.relative_position_bias_table + window_size = self.window_size + index = self.relative_position_index.reshape([-1]) + + relative_position_bias = paddle.index_select( + relative_position_bias_table, index + ) + relative_position_bias = relative_position_bias.reshape( + [window_size[0] * window_size[1], window_size[0] * window_size[1], -1] + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.transpose( + [2, 0, 1] + ) # nH, Wh*Ww, Wh*Ww + relative_position_bias = relative_position_bias.unsqueeze(0) + self.register_buffer("relative_position_bias", relative_position_bias) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape([B_, N, 3, self.num_heads, C // self.num_heads]) + .transpose([2, 0, 3, 1, 4]) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = paddle.mm(q, k.transpose([0, 1, 3, 2])) + + if self.training or not hasattr(self, "relative_position_bias"): + index = self.relative_position_index.reshape([-1]) + + relative_position_bias = paddle.index_select( + self.relative_position_bias_table, index + ) + relative_position_bias = relative_position_bias.reshape( + [ + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ] + ) # Wh*Ww,Wh*Ww,nH + + relative_position_bias = relative_position_bias.transpose( + [2, 0, 1] + ) # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + else: + attn = attn + self.relative_position_bias + + if mask is not None: + nW = mask.shape[0] + attn = attn.reshape([B_ // nW, nW, self.num_heads, N, N]) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.reshape([-1, self.num_heads, N, N]) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + # x = (attn @ v).transpose(1, 2).reshape([B_, N, C]) + x = paddle.mm(attn, v).transpose([0, 2, 1, 3]).reshape([B_, N, C]) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self): + return "dim={}, window_size={}, num_heads={}".format( + self.dim, self.window_size, self.num_heads + ) + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class Block(nn.Layer): + """AFNO network block. + + Args: + dim (int): The input tensor dimension. + mlp_ratio (float, optional): The ratio used in MLP. Defaults to 4.0. + drop (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + activation (str, optional): Name of activation function. Defaults to "gelu". + norm_layer (nn.Layer, optional): Class of norm layer. Defaults to nn.LayerNorm. + double_skip (bool, optional): Whether use double skip. Defaults to True. + num_blocks (int, optional): The number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + """ + + def __init__( + self, + dim: int, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + attn_channel_ratio=0.5, + mlp_ratio: float = 4.0, + drop: float = 0.0, + drop_path: float = 0.0, + activation: str = "gelu", + norm_layer: nn.Layer = nn.LayerNorm, + double_skip: bool = True, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.attn_channle_ratio = attn_channel_ratio + self.attn_dim = int(dim * attn_channel_ratio) + + self.norm1 = norm_layer(dim) + + if dim - self.attn_dim > 0: + self.filter = AFNO2D( + dim - self.attn_dim, + num_blocks, + sparsity_threshold, + hard_thresholding_fraction, + ) + if self.attn_dim > 0: + self.attn = WindowAttention( + self.attn_dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = paddle.zeros([1, Hp, Wp, 1], dtype="float32") # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + try: + img_mask[:, h, w, :] = cnt + except: + pass + + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.reshape( + [-1, self.window_size * self.window_size] + ) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + huns = -100.0 * paddle.ones_like(attn_mask) + attn_mask = huns * (attn_mask != 0).astype("float32") + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + activation=activation, + drop=drop, + ) + self.double_skip = double_skip + + def attn_forward(self, x): + B, H, W, C = x.shape + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, [0, pad_l, 0, pad_b, 0, pad_r, 0, pad_t]) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = paddle.roll( + x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2) + ) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.reshape( + [-1, self.window_size * self.window_size, C] + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C]) + shifted_x = window_reverse( + attn_windows, self.window_size, Hp, Wp, C + ) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = paddle.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), axis=(1, 2) + ) + else: + x = shifted_x + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :] + return x + + def afno_forward(self, x): + x = self.filter(x) + return x + + def forward(self, x): + B, H, W, C = x.shape + residual = x + x = self.norm1(x) + + if self.attn_dim == 0: + x = self.afno_forward(x) + elif self.attn_dim == self.dim: + x = self.attn_forward(x) + else: # self.attn_dim > 0 and self.attn_dim < self.dim + x_attn = x[:, :, :, : self.attn_dim] + x_afno = x[:, :, :, self.attn_dim :] + x_attn = self.attn_forward(x_attn) + x_afno = self.afno_forward(x_afno) + + x = paddle.concat([x_attn, x_afno], axis=-1) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = x + residual + return x + + +class AFNOAttnParallelNet(base.Arch): + """Adaptive Fourier Neural Network. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 20. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 20, + embed_dim: int = 768, + depth: int = 12, + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps: int = 1, + window_size=7, + num_heads=8, + attn_channel_ratio=0.5, + use_recompute=False, + merge_label=False, + merge_weights_n=None, + merge_weights_m=None, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + self.window_size = window_size + self.num_heads = num_heads + if not isinstance(attn_channel_ratio, list): + self.attn_channel_ratio = [attn_channel_ratio] * depth + else: + self.attn_channel_ratio = attn_channel_ratio + assert len(self.attn_channel_ratio) == depth + + self.use_recompute = use_recompute + self.merge_label = merge_label + if merge_label is True: + self.merge_weights_n = paddle.to_tensor( + np.load(merge_weights_n), dtype=paddle.float32 + ) + self.merge_weights_m = paddle.to_tensor( + np.load(merge_weights_m), dtype=paddle.float32 + ) + + self.merge_weights_n = self.merge_weights_n.unsqueeze(0).unsqueeze(0) + self.merge_weights_m = self.merge_weights_m.unsqueeze(0).unsqueeze(0) + + norm_layer = partial(nn.LayerNorm, epsilon=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + data = paddle.zeros((1, num_patches, embed_dim)) + data = initializer.trunc_normal_(data, std=0.02) + self.pos_embed = paddle.create_parameter( + shape=data.shape, + dtype=data.dtype, + default_initializer=nn.initializer.Assign(data), + ) + + self.time_embed = TimeFeatureEmbedding( + d_model=embed_dim, embed_type="fixed", freq="h" + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + self.blocks = nn.LayerList( + [ + Block( + dim=embed_dim, + input_resolution=(self.h, self.w), + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + attn_channel_ratio=self.attn_channel_ratio[i], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + for i in range(depth) + ] + ) + + self.norm = norm_layer(embed_dim) + self.head = nn.Linear( + embed_dim, + self.out_channels * self.patch_size[0] * self.patch_size[1], + bias_attr=False, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x, x_time): + B = x.shape[0] + x = self.patch_embed(x) + + x = x + self.pos_embed + self.time_embed(x_time, x.shape[1]) + x = self.pos_drop(x) + + x = x.reshape((B, self.h, self.w, self.embed_dim)) + + for block in self.blocks: + # x = block(x) + if not self.use_recompute: + x = block(x) + else: + x = recompute(block, x) + + x = self.head(x) + + b = x.shape[0] + p1 = self.patch_size[0] + p2 = self.patch_size[1] + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + c_out = x.shape[3] // (p1 * p2) + x = x.reshape((b, h, w, p1, p2, c_out)) + x = x.transpose((0, 5, 1, 3, 2, 4)) + x = x.reshape((b, c_out, h * p1, w * p2)) + + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x, x_time): + if self._input_transform is not None: + x = self._input_transform(x) + x_tensor = self.concat_to_tensor(x, self.input_keys) + + y = [] + input = x_tensor + for i in range(self.num_timestamps): + out = self.forward_tensor(input, x_time[i]) + y.append(out) + if self.merge_label: + input = ( + self.merge_weights_m * out + + self.merge_weights_n * x[f"{self.input_keys[0]}_{i}_merge"] + ) + else: + input = out + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y + + # def forward(self, x): + # x_time = ['2020/01/02/0'] + # out = self.forward_tensor(x,x_time) + # return out + + +class AFNOAttnParallelUNet(base.Arch): + """Adaptive Fourier Neural Network. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 20. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 20, + embed_dim: int = 768, + depths=[4, 2, 2, 4], + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps: int = 1, + window_size=7, + num_heads=8, + attn_channel_ratio=0.25, + use_recompute=False, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + self.window_size = window_size + self.num_heads = num_heads + depth = sum(depths) + if not isinstance(attn_channel_ratio, list): + self.attn_channel_ratio = [attn_channel_ratio] * depth + else: + self.attn_channel_ratio = attn_channel_ratio + assert len(self.attn_channel_ratio) == depth + + self.use_recompute = use_recompute + + norm_layer = partial(nn.LayerNorm, epsilon=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + data = paddle.zeros((1, num_patches, embed_dim)) + data = initializer.trunc_normal_(data, std=0.02) + self.pos_embed = paddle.create_parameter( + shape=data.shape, + dtype=data.dtype, + default_initializer=nn.initializer.Assign(data), + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] + self.num_layers = len(depths) + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + # layer0 + self.block0 = nn.Sequential() + for i in range(depths[0]): + block = Block( + dim=embed_dim, + input_resolution=(self.h, self.w), + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + attn_channel_ratio=self.attn_channel_ratio[i], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + self.block0.add_sublayer(str(i), block) + + # layer 1 + self.block1 = nn.Sequential() + self.patch_merger = PatchMerging(embed_dim, norm_layer) + for i in range(depths[1]): + i = depths[0] + i + block = Block( + dim=embed_dim * 2, + input_resolution=(self.h // 2, self.w // 2), + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + attn_channel_ratio=self.attn_channel_ratio[i], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + self.block1.add_sublayer(str(i), block) + + # layer 2 + self.block2 = nn.Sequential() + for i in range(depths[2]): + i = depths[0] + depths[1] + i + block = Block( + dim=embed_dim * 2, + input_resolution=(self.h // 2, self.w // 2), + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + attn_channel_ratio=self.attn_channel_ratio[i], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + self.block2.add_sublayer(str(i), block) + + # layer 3 + self.block3 = nn.Sequential() + self.upsample = UpSample(self.h, self.w, embed_dim * 2, embed_dim, norm_layer) + for i in range(depths[3]): + i = depths[0] + depths[1] + depths[2] + i + block = Block( + dim=embed_dim, + input_resolution=(self.h, self.w), + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + attn_channel_ratio=self.attn_channel_ratio[i], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + self.block3.add_sublayer(str(i), block) + + self.norm = norm_layer(embed_dim) + self.head = nn.Linear( + embed_dim * 2, + self.out_channels * self.patch_size[0] * self.patch_size[1], + bias_attr=False, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape((B, self.h, self.w, self.embed_dim)) + + if not self.use_recompute: + x0 = self.block0(x) + else: + x0 = recompute(self.block0, x) + + x1 = self.patch_merger(x0) + if not self.use_recompute: + x1 = self.block1(x1) + else: + x1 = recompute(self.block1, x1) + + if not self.use_recompute: + x2 = self.block2(x1) + else: + x2 = recompute(self.block2, x1) + + x3 = self.upsample(x2) + if not self.use_recompute: + x3 = self.block3(x3) + else: + x3 = recompute(self.block3, x3) + + x = paddle.concat([x0, x3], axis=-1) + + if not self.use_recompute: + x = self.head(x) + else: + x = recompute(self.head, x) + + b = x.shape[0] + p1 = self.patch_size[0] + p2 = self.patch_size[1] + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + c_out = x.shape[3] // (p1 * p2) + x = x.reshape((b, h, w, p1, p2, c_out)) + x = x.transpose((0, 5, 1, 3, 2, 4)) + x = x.reshape((b, c_out, h * p1, w * p2)) + + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys) + + y = [] + input = x + for _ in range(self.num_timestamps): + out = self.forward_tensor(input) + y.append(out) + input = out + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y + + +class BlockV2(nn.Layer): + """AFNO network block. + + Args: + dim (int): The input tensor dimension. + mlp_ratio (float, optional): The ratio used in MLP. Defaults to 4.0. + drop (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + activation (str, optional): Name of activation function. Defaults to "gelu". + norm_layer (nn.Layer, optional): Class of norm layer. Defaults to nn.LayerNorm. + double_skip (bool, optional): Whether use double skip. Defaults to True. + num_blocks (int, optional): The number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + """ + + def __init__( + self, + dim: int, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + mlp_ratio: float = 4.0, + drop: float = 0.0, + drop_path: float = 0.0, + activation: str = "gelu", + norm_layer: nn.Layer = nn.LayerNorm, + double_skip: bool = True, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + + self.norm1 = norm_layer(dim) + + self.filter = AFNO2D( + dim, + num_blocks, + sparsity_threshold, + hard_thresholding_fraction, + ) + + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = paddle.zeros([1, Hp, Wp, 1], dtype="float32") # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + try: + img_mask[:, h, w, :] = cnt + except: + pass + + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.reshape( + [-1, self.window_size * self.window_size] + ) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + huns = -100.0 * paddle.ones_like(attn_mask) + attn_mask = huns * (attn_mask != 0).astype("float32") + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.reduce = nn.Linear(2 * dim, dim) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + activation=activation, + drop=drop, + ) + self.double_skip = double_skip + + def attn_forward(self, x): + B, H, W, C = x.shape + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, [0, pad_l, 0, pad_b, 0, pad_r, 0, pad_t]) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = paddle.roll( + x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2) + ) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.reshape( + [-1, self.window_size * self.window_size, C] + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C]) + shifted_x = window_reverse( + attn_windows, self.window_size, Hp, Wp, C + ) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = paddle.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), axis=(1, 2) + ) + else: + x = shifted_x + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :] + return x + + def afno_forward(self, x): + x = self.filter(x) + return x + + def forward(self, x): + residual = x + x = self.norm1(x) + + x_attn = self.attn_forward(x) + x_afno = self.afno_forward(x) + + x = paddle.concat([x_attn, x_afno], axis=-1) + x = self.reduce(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = x + residual + return x + + +class AFNOAttnParallelNetV2(base.Arch): + """Adaptive Fourier Neural Network. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 20. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 20, + embed_dim: int = 768, + depth: int = 12, + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps: int = 1, + window_size=7, + num_heads=8, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + self.window_size = window_size + self.num_heads = num_heads + + norm_layer = partial(nn.LayerNorm, epsilon=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + data = paddle.zeros((1, num_patches, embed_dim)) + data = initializer.trunc_normal_(data, std=0.02) + self.pos_embed = paddle.create_parameter( + shape=data.shape, + dtype=data.dtype, + default_initializer=nn.initializer.Assign(data), + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + self.blocks = nn.LayerList( + [ + BlockV2( + dim=embed_dim, + input_resolution=(self.h, self.w), + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + for i in range(depth) + ] + ) + + self.norm = norm_layer(embed_dim) + self.head = nn.Linear( + embed_dim, + self.out_channels * self.patch_size[0] * self.patch_size[1], + bias_attr=False, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape((B, self.h, self.w, self.embed_dim)) + + for block in self.blocks: + x = block(x) + + x = self.head(x) + + b = x.shape[0] + p1 = self.patch_size[0] + p2 = self.patch_size[1] + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + c_out = x.shape[3] // (p1 * p2) + x = x.reshape((b, h, w, p1, p2, c_out)) + x = x.transpose((0, 5, 1, 3, 2, 4)) + x = x.reshape((b, c_out, h * p1, w * p2)) + + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys) + + y = [] + input = x + for _ in range(self.num_timestamps): + out = self.forward_tensor(input) + y.append(out) + input = out + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y + + +class BlockV3(nn.Layer): + """AFNO network block. + + Args: + dim (int): The input tensor dimension. + mlp_ratio (float, optional): The ratio used in MLP. Defaults to 4.0. + drop (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + activation (str, optional): Name of activation function. Defaults to "gelu". + norm_layer (nn.Layer, optional): Class of norm layer. Defaults to nn.LayerNorm. + double_skip (bool, optional): Whether use double skip. Defaults to True. + num_blocks (int, optional): The number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + """ + + def __init__( + self, + dim: int, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + mlp_ratio: float = 4.0, + drop: float = 0.0, + drop_path: float = 0.0, + activation: str = "gelu", + norm_layer: nn.Layer = nn.LayerNorm, + double_skip: bool = True, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + + self.norm1 = norm_layer(dim) + + self.reduce_attn = nn.Linear(dim, dim // 4) + self.reduce_afno = nn.Linear(dim, dim // 4 * 3) + + self.filter = AFNO2D( + dim // 4 * 3, + num_blocks, + sparsity_threshold, + hard_thresholding_fraction, + ) + + self.attn = WindowAttention( + dim // 4, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = paddle.zeros([1, Hp, Wp, 1], dtype="float32") # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + try: + img_mask[:, h, w, :] = cnt + except: + pass + + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.reshape( + [-1, self.window_size * self.window_size] + ) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + huns = -100.0 * paddle.ones_like(attn_mask) + attn_mask = huns * (attn_mask != 0).astype("float32") + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + activation=activation, + drop=drop, + ) + self.double_skip = double_skip + + def attn_forward(self, x): + B, H, W, C = x.shape + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, [0, pad_l, 0, pad_b, 0, pad_r, 0, pad_t]) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = paddle.roll( + x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2) + ) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.reshape( + [-1, self.window_size * self.window_size, C] + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C]) + shifted_x = window_reverse( + attn_windows, self.window_size, Hp, Wp, C + ) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = paddle.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), axis=(1, 2) + ) + else: + x = shifted_x + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :] + return x + + def afno_forward(self, x): + x = self.filter(x) + return x + + def forward(self, x): + residual = x + x = self.norm1(x) + + x_attn = self.reduce_attn(x) + x_attn = self.attn_forward(x_attn) + + x_afno = self.reduce_afno(x) + x_afno = self.afno_forward(x_afno) + + x = paddle.concat([x_attn, x_afno], axis=-1) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = self.drop_path(x) + x = x + residual + return x + + +class AFNOAttnParallelNetV3(base.Arch): + """Adaptive Fourier Neural Network. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). + output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). + img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440). + patch_size (Tuple[int, ...], optional): Path. Defaults to (8, 8). + in_channels (int, optional): The input tensor channels. Defaults to 20. + out_channels (int, optional): The output tensor channels. Defaults to 20. + embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768. + depth (int, optional): Number of transformer depth. Defaults to 12. + mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0. + drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0. + drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0. + num_blocks (int, optional): Number of blocks. Defaults to 8. + sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01. + hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0. + num_timestamps (int, optional): Number of timestamp. Defaults to 1. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.AFNONet(("input", ), ("output", )) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + img_size: Tuple[int, ...] = (720, 1440), + patch_size: Tuple[int, ...] = (8, 8), + in_channels: int = 20, + out_channels: int = 20, + embed_dim: int = 768, + depth: int = 12, + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 8, + sparsity_threshold: float = 0.01, + hard_thresholding_fraction: float = 1.0, + num_timestamps: int = 1, + window_size=7, + num_heads=8, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + self.img_size = img_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + self.num_blocks = num_blocks + self.num_timestamps = num_timestamps + self.window_size = window_size + self.num_heads = num_heads + + norm_layer = partial(nn.LayerNorm, epsilon=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + data = paddle.zeros((1, num_patches, embed_dim)) + data = initializer.trunc_normal_(data, std=0.02) + self.pos_embed = paddle.create_parameter( + shape=data.shape, + dtype=data.dtype, + default_initializer=nn.initializer.Assign(data), + ) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] + + self.h = img_size[0] // self.patch_size[0] + self.w = img_size[1] // self.patch_size[1] + + self.blocks = nn.LayerList( + [ + BlockV3( + dim=embed_dim, + input_resolution=(self.h, self.w), + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + num_blocks=self.num_blocks, + sparsity_threshold=sparsity_threshold, + hard_thresholding_fraction=hard_thresholding_fraction, + ) + for i in range(depth) + ] + ) + + self.norm = norm_layer(embed_dim) + self.head = nn.Linear( + embed_dim, + self.out_channels * self.patch_size[0] * self.patch_size[1], + bias_attr=False, + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + initializer.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + initializer.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + initializer.ones_(m.weight) + initializer.zeros_(m.bias) + elif isinstance(m, nn.Conv2D): + initializer.conv_init_(m) + + def forward_tensor(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + x = x.reshape((B, self.h, self.w, self.embed_dim)) + + for block in self.blocks: + x = block(x) + + x = self.head(x) + + b = x.shape[0] + p1 = self.patch_size[0] + p2 = self.patch_size[1] + h = self.img_size[0] // self.patch_size[0] + w = self.img_size[1] // self.patch_size[1] + c_out = x.shape[3] // (p1 * p2) + x = x.reshape((b, h, w, p1, p2, c_out)) + x = x.transpose((0, 5, 1, 3, 2, 4)) + x = x.reshape((b, c_out, h * p1, w * p2)) + + return x + + def split_to_dict( + self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...] + ): + return {key: data_tensors[i] for i, key in enumerate(keys)} + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys) + + y = [] + input = x + for _ in range(self.num_timestamps): + out = self.forward_tensor(input) + y.append(out) + input = out + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(y) + return y diff --git a/jointContribution/yinglong/ppsci/arch/base.py b/jointContribution/yinglong/ppsci/arch/base.py new file mode 100644 index 0000000000..a9b13f4942 --- /dev/null +++ b/jointContribution/yinglong/ppsci/arch/base.py @@ -0,0 +1,151 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable +from typing import Dict +from typing import Tuple + +import numpy as np +import paddle +from paddle import nn + +from ppsci.utils import logger + + +class Arch(nn.Layer): + """Base class for Network.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._input_transform = None + self._output_transform = None + + def forward(self, *args, **kwargs): + raise NotImplementedError("Arch.forward is not implemented") + + @property + def num_params(self) -> int: + """Return number of parameters within network. + + Returns: + int: Number of parameters. + """ + num = 0 + for name, param in self.named_parameters(): + if hasattr(param, "shape"): + num += np.prod(list(param.shape)) + else: + logger.warning(f"{name} has no attribute 'shape'") + return num + + def concat_to_tensor( + self, data_dict: Dict[str, paddle.Tensor], keys: Tuple[str, ...], axis=-1 + ) -> Tuple[paddle.Tensor, ...]: + """Concatenate tensors from dict in the order of given keys. + + Args: + data_dict (Dict[str, paddle.Tensor]): Dict contains tensor. + keys (Tuple[str, ...]): Keys tensor fetched from. + axis (int, optional): Axis concate at. Defaults to -1. + + Returns: + Tuple[paddle.Tensor, ...]: Concatenated tensor. + """ + if len(keys) == 1: + return data_dict[keys[0]] + data = [data_dict[key] for key in keys] + return paddle.concat(data, axis) + def concat_to_list( + self, data_dict: Dict[str, paddle.Tensor], keys: Tuple[str, ...], axis=-1 + ) -> Tuple[paddle.Tensor, ...]: + """Concatenate tensors from dict in the order of given keys. + + Args: + data_dict (Dict[str, paddle.Tensor]): Dict contains tensor. + keys (Tuple[str, ...]): Keys tensor fetched from. + axis (int, optional): Axis concate at. Defaults to -1. + + Returns: + Tuple[paddle.Tensor, ...]: Concatenated tensor. + """ + if len(keys) == 1: + return data_dict[keys[0]] + data = [data_dict[key] for key in keys] + return data + + def split_to_dict( + self, data_tensor: paddle.Tensor, keys: Tuple[str, ...], axis=-1 + ) -> Dict[str, paddle.Tensor]: + """Split tensor and wrap into a dict by given keys. + + Args: + data_tensor (Dict[str, paddle.Tensor]): Tensor to be split. + keys (Tuple[str, ...]): Keys tensor mapping to. + axis (int, optional): Axis split at. Defaults to -1. + + Returns: + Dict[str, paddle.Tensor]: Dict contains tensor. + """ + if len(keys) == 1: + return {keys[0]: data_tensor} + data = paddle.split(data_tensor, len(keys), axis=axis) + return {key: data[i] for i, key in enumerate(keys)} + + def register_input_transform( + self, + transform: Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]], + ): + """Register input transform. + + Args: + transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]): + Input transform of network, receive a single tensor dict and return a single tensor dict. + """ + self._input_transform = transform + + def register_output_transform( + self, + transform: Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]], + ): + """Register output transform. + + Args: + transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]): + Output transform of network, receive a single tensor dict and return a single tensor dict. + """ + self._output_transform = transform + + def __str__(self): + num_fc = 0 + num_conv = 0 + num_bn = 0 + for layer in self.sublayers(include_self=True): + if isinstance(layer, nn.Linear): + num_fc += 1 + elif isinstance(layer, (nn.Conv2D, nn.Conv3D, nn.Conv1D)): + num_conv += 1 + elif isinstance(layer, (nn.BatchNorm, nn.BatchNorm2D, nn.BatchNorm3D)): + num_bn += 1 + + return ", ".join( + [ + self.__class__.__name__, + f"input_keys = {self.input_keys}", + f"output_keys = {self.output_keys}", + f"num_fc = {num_fc}", + f"num_conv = {num_conv}", + f"num_bn = {num_bn}", + f"num_params = {self.num_params}", + ] + ) diff --git a/jointContribution/yinglong/ppsci/arch/mlp.py b/jointContribution/yinglong/ppsci/arch/mlp.py new file mode 100644 index 0000000000..7eceb56948 --- /dev/null +++ b/jointContribution/yinglong/ppsci/arch/mlp.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from typing import Tuple +from typing import Union + +from paddle import nn + +from ppsci.arch import activation as act_mod +from ppsci.arch import base + + +class MLP(base.Arch): + """Multi layer perceptron network. + + Args: + input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z"). + output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w"). + num_layers (int): Number of hidden layers. + hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size. + An integer for all layers, or list of integer specify each layer's size. + activation (str, optional): Name of activation function. Defaults to "tanh". + skip_connection (bool, optional): Whether to use skip connection. Defaults to False. + weight_norm (bool, optional): Whether to apply weight norm on parameter(s). Defaults to False. + input_dim (Optional[int], optional): Number of input's dimension. Defaults to None. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128) + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + num_layers: int, + hidden_size: Union[int, Tuple[int, ...]], + activation: str = "tanh", + skip_connection: bool = False, + weight_norm: bool = False, + input_dim: Optional[int] = None, + ): + super().__init__() + self.input_keys = input_keys + self.output_keys = output_keys + self.linears = [] + if isinstance(hidden_size, (tuple, list)): + if num_layers is not None: + raise ValueError( + "num_layers should be None when hidden_size is specified" + ) + elif isinstance(hidden_size, int): + if not isinstance(num_layers, int): + raise ValueError( + "num_layers should be an int when hidden_size is an int" + ) + hidden_size = [hidden_size] * num_layers + else: + raise ValueError( + f"hidden_size should be list of int or int" + f"but got {type(hidden_size)}" + ) + + # initialize FC layer(s) + cur_size = len(self.input_keys) if input_dim is None else input_dim + for _size in hidden_size: + self.linears.append(nn.Linear(cur_size, _size)) + if weight_norm: + self.linears[-1] = nn.utils.weight_norm(self.linears[-1], dim=1) + cur_size = _size + self.linears = nn.LayerList(self.linears) + + self.last_fc = nn.Linear(cur_size, len(self.output_keys)) + + # initialize activation function + self.act = act_mod.get_activation(activation) + + self.skip_connection = skip_connection + + def forward_tensor(self, x): + y = x + skip = None + for i, linear in enumerate(self.linears): + y = linear(y) + if self.skip_connection and i % 2 == 0: + if skip is not None: + skip = y + y = y + skip + else: + skip = y + y = self.act(y) + + y = self.last_fc(y) + + return y + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + y = self.concat_to_tensor(x, self.input_keys, axis=-1) + y = self.forward_tensor(y) + y = self.split_to_dict(y, self.output_keys, axis=-1) + + if self._output_transform is not None: + y = self._output_transform(y) + return y diff --git a/jointContribution/yinglong/ppsci/arch/time_embedding.py b/jointContribution/yinglong/ppsci/arch/time_embedding.py new file mode 100644 index 0000000000..4745238d30 --- /dev/null +++ b/jointContribution/yinglong/ppsci/arch/time_embedding.py @@ -0,0 +1,31 @@ +import pandas as pd +from .timefeatures import time_features +import numpy as np +from paddle import nn +import paddle +def set_time(time_stamp): + + time_stamp = [pd.to_datetime(date_str, format='%Y/%m/%d/%H') for date_str in time_stamp] + time_stamp = pd.DataFrame({'date': time_stamp}) + + + time_feature = time_features(time_stamp, timeenc=1, freq='h').astype(np.float32) + time_feature = paddle.to_tensor(time_feature) + + return time_feature + +class TimeFeatureEmbedding(nn.Layer): + def __init__(self, d_model, embed_type='timeF', freq='h'): + super(TimeFeatureEmbedding, self).__init__() + + freq_map = {'h':4, 't':5, 's':6, 'm':1, 'a':1, 'w':2, 'd':3, 'b':3} + d_inp = freq_map[freq] + self.embed = nn.Linear(d_inp, d_model) + + def forward(self, x, seq_len): + x = set_time(x) + + time_feature = self.embed(x) + time_feature = time_feature.unsqueeze(1) + time_feature = paddle.expand(time_feature ,[time_feature .shape[0], seq_len, time_feature .shape[2]]) + return time_feature \ No newline at end of file diff --git a/jointContribution/yinglong/ppsci/arch/timefeatures.py b/jointContribution/yinglong/ppsci/arch/timefeatures.py new file mode 100644 index 0000000000..2a61ef6038 --- /dev/null +++ b/jointContribution/yinglong/ppsci/arch/timefeatures.py @@ -0,0 +1,151 @@ +from typing import List + +import numpy as np +import pandas as pd +from pandas.tseries import offsets +from pandas.tseries.frequencies import to_offset + +class TimeFeature: + def __init__(self): + pass + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + pass + + def __repr__(self): + return self.__class__.__name__ + "()" + +class SecondOfMinute(TimeFeature): + """Minute of hour encoded as value between [-0.5, 0.5]""" + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.second / 59.0 - 0.5 + +class MinuteOfHour(TimeFeature): + """Minute of hour encoded as value between [-0.5, 0.5]""" + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.minute / 59.0 - 0.5 + +class HourOfDay(TimeFeature): + """Hour of day encoded as value between [-0.5, 0.5]""" + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.hour / 23.0 - 0.5 + +class DayOfWeek(TimeFeature): + """Hour of day encoded as value between [-0.5, 0.5]""" + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.dayofweek / 6.0 - 0.5 + +class DayOfMonth(TimeFeature): + """Day of month encoded as value between [-0.5, 0.5]""" + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.day - 1) / 30.0 - 0.5 + +class DayOfYear(TimeFeature): + """Day of year encoded as value between [-0.5, 0.5]""" + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.dayofyear - 1) / 365.0 - 0.5 + +class MonthOfYear(TimeFeature): + """Month of year encoded as value between [-0.5, 0.5]""" + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.month - 1) / 11.0 - 0.5 + +class WeekOfYear(TimeFeature): + """Week of year encoded as value between [-0.5, 0.5]""" + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.week - 1) / 52.0 - 0.5 + +def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: + """ + Returns a list of time features that will be appropriate for the given frequency string. + Parameters + ---------- + freq_str + Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. + """ + + features_by_offsets = { + offsets.YearEnd: [], + offsets.QuarterEnd: [MonthOfYear], + offsets.MonthEnd: [MonthOfYear], + offsets.Week: [DayOfMonth, WeekOfYear], + offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], + offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], + offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], + offsets.Minute: [ + MinuteOfHour, + HourOfDay, + DayOfWeek, + DayOfMonth, + DayOfYear, + ], + offsets.Second: [ + SecondOfMinute, + MinuteOfHour, + HourOfDay, + DayOfWeek, + DayOfMonth, + DayOfYear, + ], + } + + offset = to_offset(freq_str) + + for offset_type, feature_classes in features_by_offsets.items(): + if isinstance(offset, offset_type): + return [cls() for cls in feature_classes] + + supported_freq_msg = f""" + Unsupported frequency {freq_str} + The following frequencies are supported: + Y - yearly + alias: A + M - monthly + W - weekly + D - daily + B - business days + H - hourly + T - minutely + alias: min + S - secondly + """ + raise RuntimeError(supported_freq_msg) + +def time_features(dates, timeenc=1, freq='h'): + """ + > `time_features` takes in a `dates` dataframe with a 'dates' column and extracts the date down to `freq` where freq can be any of the following if `timeenc` is 0: + > * m - [month] + > * w - [month] + > * d - [month, day, weekday] + > * b - [month, day, weekday] + > * h - [month, day, weekday, hour] + > * t - [month, day, weekday, hour, *minute] + > + > If `timeenc` is 1, a similar, but different list of `freq` values are supported (all encoded between [-0.5 and 0.5]): + > * Q - [month] + > * M - [month] + > * W - [Day of month, week of year] + > * D - [Day of week, day of month, day of year] + > * B - [Day of week, day of month, day of year] + > * H - [Hour of day, day of week, day of month, day of year] + > * T - [Minute of hour*, hour of day, day of week, day of month, day of year] + > * S - [Second of minute, minute of hour, hour of day, day of week, day of month, day of year] + + *minute returns a number from 0-3 corresponding to the 15 minute period it falls into. + """ + if timeenc==0: + dates['month'] = dates.date.apply(lambda row:row.month,1) + dates['day'] = dates.date.apply(lambda row:row.day,1) + dates['weekday'] = dates.date.apply(lambda row:row.weekday(),1) + dates['hour'] = dates.date.apply(lambda row:row.hour,1) + dates['minute'] = dates.date.apply(lambda row:row.minute,1) + dates['minute'] = dates.minute.map(lambda x:x//15) + freq_map = { + 'y':[],'m':['month'],'w':['month'],'d':['month','day','weekday'], + 'b':['month','day','weekday'],'h':['month','day','weekday','hour'], + 't':['month','day','weekday','hour','minute'], + } + return dates[freq_map[freq.lower()]].values + if timeenc==1: + dates = pd.to_datetime(dates.date.values) + return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]).transpose(1,0) diff --git a/jointContribution/yinglong/ppsci/constraint/__init__.py b/jointContribution/yinglong/ppsci/constraint/__init__.py new file mode 100644 index 0000000000..be14e4d200 --- /dev/null +++ b/jointContribution/yinglong/ppsci/constraint/__init__.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from ppsci.constraint.base import Constraint +# from ppsci.constraint.boundary_constraint import BoundaryConstraint +# from ppsci.constraint.initial_constraint import InitialConstraint +# from ppsci.constraint.integral_constraint import IntegralConstraint +# from ppsci.constraint.interior_constraint import InteriorConstraint +# from ppsci.constraint.periodic_constraint import PeriodicConstraint +from ppsci.constraint.supervised_constraint import SupervisedConstraint +from ppsci.loss import build_loss +from ppsci.utils import logger +from ppsci.utils import misc + +__all__ = [ + "Constraint", + "BoundaryConstraint", + "InitialConstraint", + "IntegralConstraint", + "InteriorConstraint", + "PeriodicConstraint", + "SupervisedConstraint", +] + + +def build_constraint(cfg, equation_dict, geom_dict): + """Build constraint(s). + + Args: + cfg (List[AttrDict]): Constraint config list. + equation_dict (Dct[str, Equation]): Equation(s) in dict. + geom_dict (Dct[str, Geometry]): Geometry(ies) in dict. + + Returns: + Dict[str, constraint]: Constraint(s) in dict. + """ + if cfg is None: + return None + cfg = copy.deepcopy(cfg) + global_dataloader_cfg = cfg["dataloader"] + constraint_cfg = cfg["content"] + + constraint_dict = misc.PrettyOrderedDict() + for _item in constraint_cfg: + constraint_cls = next(iter(_item.keys())) + _constraint_cfg = _item[constraint_cls] + constraint_name = _constraint_cfg.get("name", constraint_cls) + + # select equation + if isinstance(_constraint_cfg["output_expr"], str): + equation_name = _constraint_cfg.pop("output_expr") + _constraint_cfg["output_expr"] = equation_dict[equation_name].equations + + # select geometry + geom_name = _constraint_cfg.pop("geom") + _constraint_cfg["geom"] = geom_dict[geom_name] + + # update complete dataloader config + local_dataloader_cfg = _constraint_cfg["dataloader"] + local_dataloader_cfg.update(global_dataloader_cfg) + + # build loss + _constraint_cfg["loss"] = build_loss(_constraint_cfg["loss"]) + + # instantiate constraint + _constraint_cfg["dataloader_cfg"] = _constraint_cfg.pop("dataloader") + constraint_dict[constraint_name] = eval(constraint_cls)(**_constraint_cfg) + + logger.debug(str(constraint_dict[constraint_name])) + + return constraint_dict diff --git a/jointContribution/yinglong/ppsci/constraint/base.py b/jointContribution/yinglong/ppsci/constraint/base.py new file mode 100644 index 0000000000..f79cace004 --- /dev/null +++ b/jointContribution/yinglong/ppsci/constraint/base.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Dict + +from paddle import io + +from ppsci import data +from ppsci import loss + + +class Constraint: + """Base class for constraint. + + Args: + dataset (io.Dataset): Dataset. + dataloader_cfg (Dict[str, Any]): Dataloader config. + loss (loss.Loss): Loss functor. + name (str): Name of constraint. + """ + + def __init__( + self, + dataset: io.Dataset, + dataloader_cfg: Dict[str, Any], + loss: loss.Loss, + name: str, + ): + self.data_loader = data.build_dataloader(dataset, dataloader_cfg) + # self.data_loader = data.dataloader.InfiniteDataLoader(self.data_loader) + self.data_iter = iter(self.data_loader) + self.loss = loss + self.name = name + + def __str__(self): + return ", ".join( + [ + self.__class__.__name__, + f"name = {self.name}", + f"input_keys = {self.input_keys}", + f"output_keys = {self.output_keys}", + f"output_expr = {self.output_expr}", + f"label_dict = {self.label_dict}", + f"loss = {self.loss}", + ] + ) diff --git a/jointContribution/yinglong/ppsci/constraint/supervised_constraint.py b/jointContribution/yinglong/ppsci/constraint/supervised_constraint.py new file mode 100644 index 0000000000..40b8bb63a4 --- /dev/null +++ b/jointContribution/yinglong/ppsci/constraint/supervised_constraint.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional + +from ppsci import loss +from ppsci.constraint import base +from ppsci.data import dataset + + +class SupervisedConstraint(base.Constraint): + """Class for supervised constraint. + + Args: + dataloader_cfg (Dict[str, Any]): Dataloader config. + loss (loss.Loss): Loss functor. + output_expr (Optional[Dict[str, Callable]]): List of label expression. + Defaults to None. + name (str, optional): Name of constraint object. Defaults to "Sup". + + Examples: + >>> import ppsci + >>> bc_sup = ppsci.constraint.SupervisedConstraint( + ... { + ... "dataset": { + ... "name": "IterableCSVDataset", + ... "file_path": "/path/to/file.csv", + ... "input_keys": ("x", "y"), + ... "label_keys": ("u", "v"), + ... }, + ... }, + ... ppsci.loss.MSELoss("mean"), + ... name="bc_sup", + ... ) # doctest: +SKIP + """ + + def __init__( + self, + dataloader_cfg: Dict[str, Any], + loss: loss.Loss, + output_expr: Optional[Dict[str, Callable]] = None, + name: str = "Sup", + ): + self.output_expr = output_expr + + # build dataset + _dataset = dataset.build_dataset(dataloader_cfg["dataset"]) + + self.input_keys = _dataset.input_keys + self.output_keys = ( + list(output_expr.keys()) if output_expr is not None else _dataset.label_keys + ) + + if self.output_expr is None: + self.output_expr = { + key: lambda out, k=key: out[k] for key in self.output_keys + } + + # construct dataloader with dataset and dataloader_cfg + super().__init__(_dataset, dataloader_cfg, loss, name) + + def __str__(self): + return ", ".join( + [ + self.__class__.__name__, + f"name = {self.name}", + f"input_keys = {self.input_keys}", + f"output_keys = {self.output_keys}", + f"output_expr = {self.output_expr}", + f"loss = {self.loss}", + ] + ) diff --git a/jointContribution/yinglong/ppsci/data/__init__.py b/jointContribution/yinglong/ppsci/data/__init__.py new file mode 100644 index 0000000000..b1a521b4b5 --- /dev/null +++ b/jointContribution/yinglong/ppsci/data/__init__.py @@ -0,0 +1,109 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +from functools import partial + +import numpy as np +import paddle.distributed as dist +from paddle import device +from paddle import io + +from ppsci.data import dataloader +from ppsci.data import dataset +from ppsci.data import process +from ppsci.data.process import batch_transform +from ppsci.data.process import transform +from ppsci.utils import logger + +__all__ = [ + "dataset", + "process", + "dataloader", + "build_dataloader", + "transform", + "batch_transform", +] + + +def worker_init_fn(worker_id, num_workers, rank, base_seed): + """Callback function on each worker subprocess after seeding and before data loading. + + Args: + worker_id (int): Worker id in [0, num_workers - 1] + num_workers (int): Number of subprocesses to use for data loading. + rank (int): Rank of process in distributed environment. If in non-distributed environment, it is a constant number `0`. + seed (int): Random seed + """ + # The seed of each worker equals to + # num_worker * rank + worker_id + user_seed + worker_seed = num_workers * rank + worker_id + base_seed + np.random.seed(worker_seed) + random.seed(worker_seed) + + +def build_dataloader(_dataset, cfg): + world_size = dist.get_world_size() + # just return IterableDataset as datalaoder + if isinstance(_dataset, io.IterableDataset): + if world_size > 1: + raise ValueError( + f"world_size({world_size}) should be 1 when using IterableDataset" + ) + return _dataset + + cfg = copy.deepcopy(cfg) + + # build sampler + sampler_cfg = cfg.pop("sampler") + sampler_cls = sampler_cfg.pop("name") + if sampler_cls == "BatchSampler": + if world_size > 1: + sampler_cls = "DistributedBatchSampler" + logger.warning( + f"Automatically use 'DistributedBatchSampler' instead of " + f"'BatchSampler' when world_size({world_size}) > 1" + ) + + sampler_cfg["batch_size"] = cfg["batch_size"] + sampler = getattr(io, sampler_cls)(_dataset, **sampler_cfg) + + # build collate_fn if specified + batch_transforms_cfg = cfg.pop("batch_transforms", None) + + collate_fn = None + if isinstance(batch_transforms_cfg, dict) and batch_transforms_cfg: + collate_fn = batch_transform.build_batch_transforms(batch_transforms_cfg) + + # build init function + init_fn = partial( + worker_init_fn, + num_workers=cfg.get("num_workers", 0), + rank=dist.get_rank(), + base_seed=cfg.get("seed", 42), + ) + + # build dataloader + dataloader_ = io.DataLoader( + dataset=_dataset, + places=device.get_device(), + batch_sampler=sampler, + collate_fn=collate_fn, + num_workers=cfg.get("num_workers", 0), + use_shared_memory=cfg.get("use_shared_memory", False), + worker_init_fn=init_fn, + ) + + return dataloader_ diff --git a/jointContribution/yinglong/ppsci/data/dataloader.py b/jointContribution/yinglong/ppsci/data/dataloader.py new file mode 100644 index 0000000000..c01d538f28 --- /dev/null +++ b/jointContribution/yinglong/ppsci/data/dataloader.py @@ -0,0 +1,45 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from paddle import io + + +class InfiniteDataLoader: + """A wrapper for infinite dataloader. + + Args: + dataloader (Union[io.DataLoader, io.IterableDataset]): A finite and iterable loader or iterable dataset to be wrapped. + """ + + def __init__(self, dataloader: Union[io.DataLoader, io.IterableDataset]): + self.dataloader = dataloader + if isinstance(dataloader, io.DataLoader): + self.dataset = dataloader.dataset + elif isinstance(dataloader, io.IterableDataset): + self.dataset = dataloader + else: + raise TypeError( + f"dataloader should be io.DataLoader or io.IterableDataset, but got {type(dataloader)}" + ) + + def __iter__(self): + while True: + dataloader_iter = iter(self.dataloader) + for batch in dataloader_iter: + yield batch + + def __len__(self): + return len(self.dataloader) diff --git a/jointContribution/yinglong/ppsci/data/dataset/__init__.py b/jointContribution/yinglong/ppsci/data/dataset/__init__.py new file mode 100644 index 0000000000..1441bd1121 --- /dev/null +++ b/jointContribution/yinglong/ppsci/data/dataset/__init__.py @@ -0,0 +1,72 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from ppsci.data.dataset.array_dataset import IterableNamedArrayDataset +from ppsci.data.dataset.array_dataset import NamedArrayDataset +# from ppsci.data.dataset.csv_dataset import CSVDataset +# from ppsci.data.dataset.csv_dataset import IterableCSVDataset +# from ppsci.data.dataset.era5_dataset import ERA5Dataset +# from ppsci.data.dataset.era5_dataset import ERA5SampledDataset +from ppsci.data.dataset.hrrr_dataset import HRRRDataset +from ppsci.data.dataset.hrrr_dataset import HRRRDatasetMultiInput +# from ppsci.data.dataset.mat_dataset import IterableMatDataset +# from ppsci.data.dataset.mat_dataset import MatDataset +# from ppsci.data.dataset.trphysx_dataset import CylinderDataset +# from ppsci.data.dataset.trphysx_dataset import LorenzDataset +# from ppsci.data.dataset.trphysx_dataset import RosslerDataset +# from ppsci.data.dataset.vtu_dataset import VtuDataset +from ppsci.data.process import transform +from ppsci.utils import logger + +__all__ = [ + "IterableNamedArrayDataset", + "NamedArrayDataset", + "CSVDataset", + "IterableCSVDataset", + "ERA5Dataset", + "ERA5SampledDataset", + "IterableMatDataset", + "MatDataset", + "CylinderDataset", + "LorenzDataset", + "RosslerDataset", + "VtuDataset", + "build_dataset", + "HRRRDataset", + "HRRRDatasetMultiInput", +] + + +def build_dataset(cfg): + """Build dataset + + Args: + cfg (List[AttrDict]): dataset config list. + + Returns: + Dict[str, io.Dataset]: dataset. + """ + cfg = copy.deepcopy(cfg) + + dataset_cls = cfg.pop("name") + if "transforms" in cfg: + cfg["transforms"] = transform.build_transforms(cfg.pop("transforms")) + + dataset = eval(dataset_cls)(**cfg) + + logger.debug(str(dataset)) + + return dataset diff --git a/jointContribution/yinglong/ppsci/data/dataset/array_dataset.py b/jointContribution/yinglong/ppsci/data/dataset/array_dataset.py new file mode 100644 index 0000000000..6eab8364a4 --- /dev/null +++ b/jointContribution/yinglong/ppsci/data/dataset/array_dataset.py @@ -0,0 +1,113 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +from typing import Optional + +import numpy as np +import paddle +from paddle import io +from paddle import vision + + +class NamedArrayDataset(io.Dataset): + """Class for Named Array Dataset. + + Args: + input (Dict[str, np.ndarray]): Input dict. + label (Dict[str, np.ndarray]): Label dict. + weight (Dict[str, np.ndarray], optional): Weight dict. + transforms (Optional[vision.Compose]): Compose object contains sample wise + transform(s). + + Examples: + >>> import ppsci + >>> input = {"x": np.random.randn(100, 1)} + >>> output = {"u": np.random.randn(100, 1)} + >>> weight = {"u": np.random.randn(100, 1)} + >>> dataset = ppsci.data.dataset.NamedArrayDataset(input, output, weight) + """ + + def __init__( + self, + input: Dict[str, np.ndarray], + label: Dict[str, np.ndarray], + weight: Dict[str, np.ndarray], + transforms: Optional[vision.Compose] = None, + ): + super().__init__() + self.input = input + self.label = label + self.input_keys = tuple(input.keys()) + self.label_keys = tuple(label.keys()) + self.weight = weight + self.transforms = transforms + self._len = len(next(iter(input.values()))) + + def __getitem__(self, idx): + input_item = {key: value[idx] for key, value in self.input.items()} + label_item = {key: value[idx] for key, value in self.label.items()} + weight_item = {key: value[idx] for key, value in self.weight.items()} + + # TODO(sensen): Transforms may be applied on label and weight. + if self.transforms is not None: + input_item = self.transforms(input_item) + + return (input_item, label_item, weight_item) + + def __len__(self): + return self._len + + +class IterableNamedArrayDataset(io.IterableDataset): + """IterableNamedArrayDataset for full-data loading. + + Args: + input (Dict[str, np.ndarray]): Input dict. + label (Dict[str, np.ndarray]): Label dict. + weight (Dict[str, np.ndarray]): Weight dict. + transforms (Optional[vision.Compose]): Compose object contains sample wise + transform(s). Defaults to None. + + Examples: + >>> import ppsci + >>> input = {"x": np.random.randn(100, 1)} + >>> label = {"u": np.random.randn(100, 1)} + >>> weight = {"u": np.random.randn(100, 1)} + >>> dataset = ppsci.data.dataset.IterableNamedArrayDataset(input, label, weight) + """ + + def __init__( + self, + input: Dict[str, np.ndarray], + label: Dict[str, np.ndarray], + weight: Dict[str, np.ndarray], + transforms: Optional[vision.Compose] = None, + ): + super().__init__() + self.input = {key: paddle.to_tensor(value) for key, value in input.items()} + self.label = {key: paddle.to_tensor(value) for key, value in label.items()} + self.weight = {key: paddle.to_tensor(value) for key, value in weight.items()} + self._len = len(next(iter(self.input.values()))) + + @property + def num_samples(self): + """Number of samples within current dataset.""" + return self._len + + def __iter__(self): + yield self.input, self.label, self.weight + + def __len__(self): + return 1 diff --git a/jointContribution/yinglong/ppsci/data/dataset/hrrr_dataset.py b/jointContribution/yinglong/ppsci/data/dataset/hrrr_dataset.py new file mode 100644 index 0000000000..eff3a4dd5d --- /dev/null +++ b/jointContribution/yinglong/ppsci/data/dataset/hrrr_dataset.py @@ -0,0 +1,298 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Dict +from typing import Optional +from typing import Tuple + +import h5py +import numpy as np +import paddle +from paddle import io +from paddle import vision + + +class HRRRDataset(io.Dataset): + """Class for HRRR dataset. + + Args: + file_path (str): Data set path. + input_keys (Tuple[str, ...]): Input keys, such as ("input",). + label_keys (Tuple[str, ...]): Output keys, such as ("output",). + precip_file_path (Optional[str]): Precipitation data set path. Defaults to None. + weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None. + vars_channel (Optional[Tuple[int, ...]]): The variable channel index in ERA5 dataset. Defaults to None. + num_label_timestamps (int, optional): Number of timestamp of label. Defaults to 1. + transforms (Optional[vision.Compose]): Compose object contains sample wise + transform(s). Defaults to None. + training (bool, optional): Whether in train mode. Defaults to True. + stride (int, optional): Stride of sampling data. Defaults to 1. + + Examples: + >>> import ppsci + >>> dataset = ppsci.data.dataset.ERA5Dataset( + ... "file_path": "/path/to/ERA5Dataset", + ... "input_keys": ("input",), + ... "label_keys": ("output",), + ... ) # doctest: +SKIP + """ + + def __init__( + self, + file_path: str, + input_keys: Tuple[str, ...], + label_keys: Tuple[str, ...], + weight_dict: Optional[Dict[str, float]] = None, + num_label_timestamps: int = 1, + vars_channel: Optional[Tuple[int, ...]] = None, + transforms: Optional[vision.Compose] = None, + training: bool = True, + stride: int = 1, + lead_time: int = 1, + extra_file_path: Optional[str] = None, + extra_vars_channel: Optional[Tuple[int, ...]] = None, + merge_label: bool = False, + ): + super().__init__() + self.file_path = file_path + self.input_keys = input_keys + self.label_keys = label_keys + + self.weight_dict = {key: 1.0 for key in self.label_keys} + if weight_dict is not None: + self.weight_dict.update(weight_dict) + + self.vars_channel = list(range(69)) if vars_channel is None else vars_channel + self.num_label_timestamps = num_label_timestamps + self.transforms = transforms + self.training = training + self.stride = stride + self.lead_time = lead_time + self.extra_file_path = extra_file_path + self.extra_vars_channel = extra_vars_channel + self.merge_label = merge_label + + self.files = self.read_data(file_path, extra_file_path) + self.num_days = len(self.files) + self.num_samples_per_day = self.files[0][0].shape[0] + self.num_samples = self.num_days * self.num_samples_per_day + + def read_data(self, path: str, extra_path: str, var="fields"): + if path.endswith(".h5"): + paths = [path] + else: + paths = [] + for root, dirs, files in os.walk(path): + for file in files: + paths.append(os.path.join(root, file)) + paths.sort() + files = [] + for path_ in paths: + _file = h5py.File(path_, "r") + if extra_path is not None: + _extra_file = h5py.File(os.path.join(extra_path, path_[-13:]), "r") + files.append([_file[var], path_[-13:-3], _extra_file[var]]) + else: + files.append([_file[var], path_[-13:-3]]) + + return files + + def __len__(self): + return self.num_samples // self.stride + + def __getitem__(self, global_idx): + + global_idx *= self.stride + + if global_idx >= self.num_samples - self.num_label_timestamps - self.lead_time: + return self.__getitem__(np.random.randint(self.__len__())) + + input_day_idx = global_idx // self.num_samples_per_day + input_hour_idx = global_idx % self.num_samples_per_day + + input_file = self.files[input_day_idx][0] + # check fake data + if len(input_file.shape) == 1: + print("Warning: fake data detected, please check your data") + return self.__getitem__(np.random.randint(self.__len__())) + input_item = {self.input_keys[0]: input_file[input_hour_idx, self.vars_channel]} + if self.extra_file_path is not None: + extra_input = self.files[input_day_idx][2][ + input_hour_idx, self.extra_vars_channel + ] + input_item[self.input_keys[0]] = np.concatenate( + [input_item[self.input_keys[0]], extra_input] + ) + + # label_item = {self.label_keys[0]: label_file[label_hour_idx, self.vars_channel]} + input_time_list = [] + input_time = self.files[input_day_idx][1] + "/" + str(input_hour_idx) + input_time_list.append(input_time) + + label_item = {} + label_time = {} + + for i in range(self.num_label_timestamps): + label_day_idx = ( + global_idx + self.lead_time + i + ) // self.num_samples_per_day + label_hour_idx = ( + global_idx + self.lead_time + i + ) % self.num_samples_per_day + label_file = self.files[label_day_idx][0] + if len(label_file.shape) == 1: + print("Warning: fake data detected, please check your data") + return self.__getitem__(np.random.randint(self.__len__())) + label_item[self.label_keys[i]] = label_file[ + label_hour_idx, self.vars_channel + ] + if self.extra_file_path is not None: + extra_label = self.files[label_day_idx][2][ + label_hour_idx, self.extra_vars_channel + ] + label_item[self.label_keys[i]] = np.concatenate( + [label_item[self.label_keys[i]], extra_label] + ) + + label_time[self.label_keys[i]] = ( + self.files[label_day_idx][1] + "/" + str(label_hour_idx) + ) + input_time = self.files[label_day_idx][1] + "/" + str(label_hour_idx) + input_time_list.append(input_time) + # merge label + if self.merge_label: + for i in range(self.num_label_timestamps): + input_item[f"{self.input_keys[0]}_{i}_merge"] = label_item[ + self.label_keys[i] + ] + # import remote_pdb as pdb;pdb.set_trace() + weight_shape = [1] * len(next(iter(label_item.values())).shape) + weight_item = { + key: np.full(weight_shape, value, paddle.get_default_dtype()) + for key, value in self.weight_dict.items() + } + + if self.transforms is not None: + input_item, label_item, weight_item = self.transforms( + (input_item, label_item, weight_item) + ) + + return input_item, label_item, weight_item, input_time_list + + +class HRRRDatasetMultiInput(HRRRDataset): + """Class for HRRR dataset. + + Args: + file_path (str): Data set path. + input_keys (Tuple[str, ...]): Input keys, such as ("input",). + label_keys (Tuple[str, ...]): Output keys, such as ("output",). + precip_file_path (Optional[str]): Precipitation data set path. Defaults to None. + weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None. + vars_channel (Optional[Tuple[int, ...]]): The variable channel index in ERA5 dataset. Defaults to None. + num_label_timestamps (int, optional): Number of timestamp of label. Defaults to 1. + transforms (Optional[vision.Compose]): Compose object contains sample wise + transform(s). Defaults to None. + training (bool, optional): Whether in train mode. Defaults to True. + stride (int, optional): Stride of sampling data. Defaults to 1. + + Examples: + >>> import ppsci + >>> dataset = ppsci.data.dataset.ERA5Dataset( + ... "file_path": "/path/to/ERA5Dataset", + ... "input_keys": ("input",), + ... "label_keys": ("output",), + ... ) # doctest: +SKIP + """ + + def __init__( + self, + file_path: str, + input_keys: Tuple[str, ...], + label_keys: Tuple[str, ...], + weight_dict: Optional[Dict[str, float]] = None, + num_input_timestamps: int = 1, + num_label_timestamps: int = 1, + vars_channel: Optional[Tuple[int, ...]] = None, + transforms: Optional[vision.Compose] = None, + training: bool = True, + stride: int = 1, + ): + super().__init__( + file_path=file_path, + input_keys=input_keys, + label_keys=label_keys, + weight_dict=weight_dict, + num_label_timestamps=num_label_timestamps, + vars_channel=vars_channel, + transforms=transforms, + training=training, + stride=stride, + ) + self.num_input_timestamps = num_input_timestamps + + def __len__(self): + return (self.num_samples - self.num_input_timestamps) // self.stride + + def __getitem__(self, global_idx): + + global_idx = global_idx * self.stride + self.num_input_timestamps + + if ( + global_idx < (self.num_input_timestamps - 1) + or global_idx >= self.num_samples - self.num_label_timestamps + ): + return self.__getitem__(np.random.randint(self.__len__())) + + input_item = {} + for i in range(self.num_input_timestamps): + + input_day_idx = (global_idx - i) // self.num_samples_per_day + input_hour_idx = (global_idx - i) % self.num_samples_per_day + + input_file = self.files[input_day_idx] + # check fake data + if len(input_file.shape) == 1: + print("Warning: fake data detected, please check your data") + return self.__getitem__(np.random.randint(self.__len__())) + input_item[self.input_keys[i]] = input_file[ + input_hour_idx, self.vars_channel + ] + # input_item = {self.input_keys[0]: input_file[input_hour_idx, self.vars_channel]} + # label_item = {self.label_keys[0]: label_file[label_hour_idx, self.vars_channel]} + + label_item = {} + for i in range(self.num_label_timestamps): + label_day_idx = (global_idx + 1 + i) // self.num_samples_per_day + label_hour_idx = (global_idx + 1 + i) % self.num_samples_per_day + label_file = self.files[label_day_idx] + if len(label_file.shape) == 1: + print("Warning: fake data detected, please check your data") + return self.__getitem__(np.random.randint(self.__len__())) + label_item[self.label_keys[i]] = label_file[ + label_hour_idx, self.vars_channel + ] + weight_shape = [1] * len(next(iter(label_item.values())).shape) + weight_item = { + key: np.full(weight_shape, value, paddle.get_default_dtype()) + for key, value in self.weight_dict.items() + } + + if self.transforms is not None: + input_item, label_item, weight_item = self.transforms( + (input_item, label_item, weight_item) + ) + + return input_item, label_item, weight_item diff --git a/jointContribution/yinglong/ppsci/data/process/__init__.py b/jointContribution/yinglong/ppsci/data/process/__init__.py new file mode 100644 index 0000000000..f46c8dd9cf --- /dev/null +++ b/jointContribution/yinglong/ppsci/data/process/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ppsci.data.process import batch_transform +from ppsci.data.process import transform + +__all__ = [ + "batch_transform", + "transform", +] diff --git a/jointContribution/yinglong/ppsci/data/process/batch_transform/__init__.py b/jointContribution/yinglong/ppsci/data/process/batch_transform/__init__.py new file mode 100644 index 0000000000..2ac5ac75e8 --- /dev/null +++ b/jointContribution/yinglong/ppsci/data/process/batch_transform/__init__.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import numbers +from collections.abc import Mapping +from collections.abc import Sequence +from typing import Any +from typing import List + +import numpy as np +import paddle +from paddle.fluid import core + +from ppsci.data.process import transform + +__all__ = ["build_batch_transforms"] + + +def default_collate_fn(batch: List[Any]) -> Any: + """Default_collate_fn for paddle dataloader. + + ref: https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/io/dataloader/collate.py#L25 + + Args: + batch (List[Any]): Batch of samples to be collated. + + Returns: + Any: Collated batch data. + """ + sample = batch[0] + if isinstance(sample, np.ndarray): + batch = np.stack(batch, axis=0) + return batch + elif isinstance(sample, (paddle.Tensor, core.eager.Tensor)): + return paddle.stack(batch, axis=0) + elif isinstance(sample, numbers.Number): + batch = np.array(batch) + return batch + elif isinstance(sample, (str, bytes)): + return batch + elif isinstance(sample, Mapping): + return {key: default_collate_fn([d[key] for d in batch]) for key in sample} + elif isinstance(sample, Sequence): + sample_fields_num = len(sample) + if not all(len(sample) == sample_fields_num for sample in iter(batch)): + raise RuntimeError("fileds number not same among samples in a batch") + return [default_collate_fn(fields) for fields in zip(*batch)] + + raise TypeError( + "batch data can only contains: tensor, numpy.ndarray, " + f"dict, list, number, None, but got {type(sample)}" + ) + + +def build_batch_transforms(cfg): + cfg = copy.deepcopy(cfg) + batch_transforms = transform.build_transforms(cfg) + + def collate_fn_batch_transforms(batch: List[Any]): + # apply batch transform on uncollated data + batch = batch_transforms(batch) + # then do collate + return default_collate_fn(batch) + + return collate_fn_batch_transforms diff --git a/jointContribution/yinglong/ppsci/data/process/batch_transform/preprocess.py b/jointContribution/yinglong/ppsci/data/process/batch_transform/preprocess.py new file mode 100644 index 0000000000..66987aa2fe --- /dev/null +++ b/jointContribution/yinglong/ppsci/data/process/batch_transform/preprocess.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jointContribution/yinglong/ppsci/data/process/transform/__init__.py b/jointContribution/yinglong/ppsci/data/process/transform/__init__.py new file mode 100644 index 0000000000..6ad9b8cdf0 --- /dev/null +++ b/jointContribution/yinglong/ppsci/data/process/transform/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from ppsci.data.process.postprocess import * +import copy + +from paddle import vision + +from ppsci.data.process.transform.preprocess import CropData +from ppsci.data.process.transform.preprocess import Log1p +from ppsci.data.process.transform.preprocess import Normalize +from ppsci.data.process.transform.preprocess import Scale +from ppsci.data.process.transform.preprocess import SqueezeData +from ppsci.data.process.transform.preprocess import Translate + +__all__ = [ + "CropData", + "Log1p", + "Normalize", + "Scale", + "SqueezeData", + "Translate", + "build_transforms", +] + + +def build_transforms(cfg): + if not cfg: + return vision.Compose([]) + cfg = copy.deepcopy(cfg) + + transform_list = [] + for _item in cfg: + transform_cls = next(iter(_item.keys())) + transform_cfg = _item[transform_cls] + transform = eval(transform_cls)(**transform_cfg) + transform_list.append(transform) + + return vision.Compose(transform_list) diff --git a/jointContribution/yinglong/ppsci/data/process/transform/preprocess.py b/jointContribution/yinglong/ppsci/data/process/transform/preprocess.py new file mode 100644 index 0000000000..44f078a63b --- /dev/null +++ b/jointContribution/yinglong/ppsci/data/process/transform/preprocess.py @@ -0,0 +1,219 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +from typing import Tuple +from typing import Union + +import numpy as np + + +class Translate: + """Translate class. + + Args: + offset (Dict[str, float]): Shift the input data according to the variable name + and coefficient specified in offset. + + Examples: + >>> import ppsci + >>> translate = ppsci.data.transform.Translate({"x": 1.0, "y": -1.0}) + """ + + def __init__(self, offset: Dict[str, float]): + self.offset = offset + + def __call__(self, data_dict): + for key in self.offset: + if key in data_dict: + data_dict[key] += self.offset[key] + return data_dict + + +class Scale: + """Scale class. + + Args: + scale (Dict[str, float]): Scale the input data according to the variable name + and coefficient specified in scale. + + Examples: + >>> import ppsci + >>> translate = ppsci.data.transform.Scale({"x": 1.5, "y": 2.0}) + """ + + def __init__(self, scale: Dict[str, float]): + self.scale = scale + + def __call__(self, data_dict): + for key in self.scale: + if key in data_dict: + data_dict[key] *= self.scale[key] + return data_dict + + +class Normalize: + """Normalize data class. + + Args: + mean (Union[np.array, Tuple[float, ...]]): Mean of training dataset. + std (Union[np.array, Tuple[float, ...]]): Standard Deviation of training dataset. + apply_keys (Tuple[str, ...], optional): Which data is the normalization method applied to. Defaults to ("input", "label"). + + Examples: + >>> import ppsci + >>> normalize = ppsci.data.transform.Normalize((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)) + """ + + def __init__( + self, + mean: Union[np.array, Tuple[float, ...]], + std: Union[np.array, Tuple[float, ...]], + apply_keys: Tuple[str, ...] = ("input", "label"), + ): + if len(apply_keys) == 0 or len(set(apply_keys) | {"input", "label"}) > 2: + raise ValueError( + f"apply_keys should be a non empty subset of ('input', 'label'), but got {apply_keys}" + ) + self.mean = mean + self.std = std + self.apply_keys = apply_keys + + def __call__(self, data): + input_item, label_item, weight_item = data + if "input" in self.apply_keys: + for key, value in input_item.items(): + input_item[key] = (value - self.mean) / self.std + if "label" in self.apply_keys: + for key, value in label_item.items(): + label_item[key] = (value - self.mean) / self.std + return input_item, label_item, weight_item + + +class Log1p: + """Calculates the natural logarithm of one plus the data, element-wise. + + Args: + scale (float, optional): Scale data. Defaults to 1.0. + apply_keys (Tuple[str, ...], optional): Which data is the log1p method applied to. Defaults to ("input", "label"). + + Examples: + >>> import ppsci + >>> log1p = ppsci.data.transform.Log1p(1e-5) + """ + + def __init__( + self, + scale: float = 1.0, + apply_keys: Tuple[str, ...] = ("input", "label"), + ): + if len(apply_keys) == 0 or len(set(apply_keys) | {"input", "label"}) > 2: + raise ValueError( + f"apply_keys should be a non empty subset of ('input', 'label'), but got {apply_keys}" + ) + self.scale = scale + self.apply_keys = apply_keys + + def __call__(self, data): + input_item, label_item, weight_item = data + if "input" in self.apply_keys: + for key, value in input_item.items(): + input_item[key] = np.log1p(value / self.scale) + if "label" in self.apply_keys: + for key, value in label_item.items(): + label_item[key] = np.log1p(value / self.scale) + return input_item, label_item, weight_item + + +class CropData: + """Crop data class. + + Args: + xmin (Tuple[int, ...]): Bottom left corner point, [x0, y0]. + xmax (Tuple[int, ...]): Top right corner point, [x1, y1]. + apply_keys (Tuple[str, ...], optional): Which data is the crop method applied to. Defaults to ("input", "label"). + + Examples: + >>> import ppsci + >>> crop_data = ppsci.data.transform.CropData((0, 0), (720, 1440)) + """ + + def __init__( + self, + xmin: Tuple[int, ...], + xmax: Tuple[int, ...], + apply_keys: Tuple[str, ...] = ("input", "label"), + ): + if len(apply_keys) == 0 or len(set(apply_keys) | {"input", "label"}) > 2: + raise ValueError( + f"apply_keys should be a non empty subset of ('input', 'label'), but got {apply_keys}" + ) + self.xmin = xmin + self.xmax = xmax + self.apply_keys = apply_keys + + def __call__(self, data): + input_item, label_item, weight_item = data + if "input" in self.apply_keys: + for key, value in input_item.items(): + input_item[key] = value[ + :, self.xmin[0] : self.xmax[0], self.xmin[1] : self.xmax[1] + ] + if "label" in self.apply_keys: + for key, value in label_item.items(): + label_item[key] = value[ + :, self.xmin[0] : self.xmax[0], self.xmin[1] : self.xmax[1] + ] + return input_item, label_item, weight_item + + +class SqueezeData: + """Squeeze data clsss. + + Args: + apply_keys (Tuple[str, ...], optional): Which data is the squeeze method applied to. Defaults to ("input", "label"). + + Examples: + >>> import ppsci + >>> squeeze_data = ppsci.data.transform.SqueezeData() + """ + + def __init__(self, apply_keys: Tuple[str, ...] = ("input", "label")): + if len(apply_keys) == 0 or len(set(apply_keys) | {"input", "label"}) > 2: + raise ValueError( + f"apply_keys should be a non empty subset of ('input', 'label'), but got {apply_keys}" + ) + self.apply_keys = apply_keys + + def __call__(self, data): + input_item, label_item, weight_item = data + if "input" in self.apply_keys: + for key, value in input_item.items(): + if value.ndim == 4: + B, C, H, W = value.shape + input_item[key] = value.reshape((B * C, H, W)) + if value.ndim != 3: + raise ValueError( + f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}" + ) + if "label" in self.apply_keys: + for key, value in label_item.items(): + if value.ndim == 4: + B, C, H, W = value.shape + label_item[key] = value.reshape((B * C, H, W)) + if value.ndim != 3: + raise ValueError( + f"Only support squeeze data to ndim=3 now, but got ndim={value.ndim}" + ) + return input_item, label_item, weight_item diff --git a/jointContribution/yinglong/ppsci/loss/__init__.py b/jointContribution/yinglong/ppsci/loss/__init__.py new file mode 100644 index 0000000000..da0b759d38 --- /dev/null +++ b/jointContribution/yinglong/ppsci/loss/__init__.py @@ -0,0 +1,60 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +# from ppsci.loss.anomaly_coef_loss import ACCLoss +from ppsci.loss.base import Loss + +# from ppsci.loss.integral import IntegralLoss +# from ppsci.loss.l1 import L1Loss +# from ppsci.loss.l1 import PeriodicL1Loss +from ppsci.loss.l2 import L2Loss +from ppsci.loss.l2 import L2RelLoss +from ppsci.loss.l2 import PeriodicL2Loss + +# from ppsci.loss.mse import MSELoss +# from ppsci.loss.mse import MSELossWithL2Decay +# from ppsci.loss.mse import PeriodicMSELoss +# from ppsci.loss.multi_loss import MultiLoss + +__all__ = [ + "Loss", + # "IntegralLoss", + # "L1Loss", + # "PeriodicL1Loss", + "L2Loss", + "L2RelLoss", + "PeriodicL2Loss", + # "MSELoss", + # "MSELossWithL2Decay", + # "PeriodicMSELoss", + # "ACCLoss", + # "MultiLoss", +] + + +def build_loss(cfg): + """Build loss. + + Args: + cfg (AttrDict): Loss config. + Returns: + Loss: Callable loss object. + """ + cfg = copy.deepcopy(cfg) + + loss_cls = cfg.pop("name") + loss = eval(loss_cls)(**cfg) + return loss diff --git a/jointContribution/yinglong/ppsci/loss/base.py b/jointContribution/yinglong/ppsci/loss/base.py new file mode 100644 index 0000000000..e9ac8c17d8 --- /dev/null +++ b/jointContribution/yinglong/ppsci/loss/base.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +from typing import Optional +from typing import Union + +from paddle import nn +from typing_extensions import Literal + + +class Loss(nn.Layer): + """Base class for loss.""" + + def __init__( + self, + reduction: Literal["mean", "sum"], + weight: Optional[Union[float, Dict[str, float]]] = None, + ): + super().__init__() + self.reduction = reduction + self.weight = weight + + def __str__(self): + return f"{self.__class__.__name__}(reduction={self.reduction}, weight={self.weight})" diff --git a/jointContribution/yinglong/ppsci/loss/l2.py b/jointContribution/yinglong/ppsci/loss/l2.py new file mode 100644 index 0000000000..412d8d1f6e --- /dev/null +++ b/jointContribution/yinglong/ppsci/loss/l2.py @@ -0,0 +1,216 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +from typing import Optional +from typing import Union + +import paddle +import paddle.nn.functional as F +from typing_extensions import Literal + +from ppsci.loss import base + + +class L2Loss(base.Loss): + r"""Class for l2 loss. + + $$ + L =\Vert \mathbf{x} - \mathbf{y} \Vert_2 + $$ + + $$ + \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} + $$ + + Args: + reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean". + weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None. + + Examples: + >>> import ppsci + >>> loss = ppsci.loss.L2Loss() + """ + + def __init__( + self, + reduction: Literal["mean", "sum"] = "mean", + weight: Optional[Union[float, Dict[str, float]]] = None, + ): + if reduction not in ["mean", "sum"]: + raise ValueError( + f"reduction should be 'mean' or 'sum', but got {reduction}" + ) + super().__init__(reduction, weight) + + def forward(self, output_dict, label_dict, weight_dict=None): + losses = 0.0 + for key in label_dict: + loss = F.mse_loss(output_dict[key], label_dict[key], "none") + if weight_dict is not None: + loss *= weight_dict[key] + + if "area" in output_dict: + loss *= output_dict["area"] + + loss = loss.sum(axis=1).sqrt() + + if self.reduction == "sum": + loss = loss.sum() + elif self.reduction == "mean": + loss = loss.mean() + + if isinstance(self.weight, (float, int)): + loss *= self.weight + elif isinstance(self.weight, dict) and key in self.weight: + loss *= self.weight[key] + + losses += loss + return losses + + +class PeriodicL2Loss(base.Loss): + r"""Class for Periodic l2 loss. + + $$ + L = \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_2 + $$ + + $\mathbf{x_l} \in \mathcal{R}^{N}$ is the first half of batch output, + $\mathbf{x_r} \in \mathcal{R}^{N}$ is the second half of batch output. + + Args: + reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean". + weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None. + + Examples: + >>> import ppsci + >>> loss = ppsci.loss.PeriodicL2Loss() + """ + + def __init__( + self, + reduction: Literal["mean", "sum"] = "mean", + weight: Optional[Union[float, Dict[str, float]]] = None, + ): + if reduction not in ["mean", "sum"]: + raise ValueError( + f"reduction should be 'mean' or 'sum', but got {reduction}" + ) + super().__init__(reduction, weight) + + def forward(self, output_dict, label_dict, weight_dict=None): + losses = 0.0 + for key in label_dict: + n_output = len(output_dict[key]) + if n_output % 2 > 0: + raise ValueError( + f"Length of output({n_output}) of key({key}) should be even." + ) + n_output //= 2 + + loss = F.mse_loss( + output_dict[key][:n_output], output_dict[key][n_output:], "none" + ) + if weight_dict: + loss *= weight_dict[key] + + if "area" in output_dict: + loss *= output_dict["area"] + + loss = loss.sum(axis=1).sqrt() + + if self.reduction == "sum": + loss = loss.sum() + elif self.reduction == "mean": + loss = loss.mean() + + if isinstance(self.weight, (float, int)): + loss *= self.weight + elif isinstance(self.weight, dict) and key in self.weight: + loss *= self.weight[key] + + losses += loss + return losses + + +class L2RelLoss(base.Loss): + r"""Class for l2 relative loss. + + $$ + L = \dfrac{\Vert \mathbf{x} - \mathbf{y} \Vert_2}{\Vert \mathbf{y} \Vert_2} + $$ + + $$ + \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} + $$ + + Args: + reduction (Literal["mean", "sum"], optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Defaults to "mean". + weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None. + + Examples: + >>> import ppsci + >>> loss = ppsci.loss.L2RelLoss() + """ + + def __init__( + self, + reduction: Literal["mean", "sum"] = "mean", + weight: Optional[Union[float, Dict[str, float]]] = None, + var_weight=None, + ): + if reduction not in ["mean", "sum"]: + raise ValueError( + f"reduction should be 'mean' or 'sum', but got {reduction}" + ) + super().__init__(reduction, weight) + self.var_weight = var_weight + + def rel_loss(self, x, y): + batch_size = x.shape[0] + x_ = x.reshape((batch_size, -1)) + y_ = y.reshape((batch_size, -1)) + + if self.var_weight is not None: + var_weight = paddle.to_tensor(self.var_weight, place=x.place) + diff = (x - y) ** 2 * var_weight.reshape([1, -1, 1, 1]) + diff_norms = paddle.sum(diff, axis=(1, 2, 3)) ** 0.5 + + diff_norms = paddle.norm(x_ - y_, p=2, axis=1) + y_norms = paddle.norm(y_, p=2, axis=1) + else: + diff_norms = paddle.norm(x_ - y_, p=2, axis=1) + y_norms = paddle.norm(y_, p=2, axis=1) + return diff_norms / y_norms + + def forward(self, output_dict, label_dict, weight_dict=None): + losses = 0 + for key in label_dict: + loss = self.rel_loss(output_dict[key], label_dict[key]) + if weight_dict is not None: + loss *= weight_dict[key] + + if self.reduction == "sum": + loss = loss.sum() + elif self.reduction == "mean": + loss = loss.mean() + + if isinstance(self.weight, float): + loss *= self.weight + elif isinstance(self.weight, dict) and key in self.weight: + loss *= self.weight[key] + + losses += loss + return losses diff --git a/jointContribution/yinglong/ppsci/metric/__init__.py b/jointContribution/yinglong/ppsci/metric/__init__.py new file mode 100644 index 0000000000..2683aa6cdf --- /dev/null +++ b/jointContribution/yinglong/ppsci/metric/__init__.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from ppsci.metric.anomaly_coef import LatitudeWeightedACC +from ppsci.metric.base import Metric + +# from ppsci.metric.l2_rel import L2Rel +from ppsci.metric.mae import MAE + +# from ppsci.metric.mse import MSE +# from ppsci.metric.rmse import RMSE +from ppsci.metric.rmse import LatitudeWeightedRMSE +from ppsci.utils import misc + +__all__ = [ + "LatitudeWeightedACC", + "Metric", + # "L2Rel", + "MAE", + # "MSE", + # "RMSE", + "LatitudeWeightedRMSE", + "build_metric", +] + + +def build_metric(cfg): + """Build metric. + + Args: + cfg (List[AttrDict]): List of metric config. + + Returns: + Dict[str, Metric]: Dict of callable metric object. + """ + cfg = copy.deepcopy(cfg) + + metric_dict = misc.PrettyOrderedDict() + for _item in cfg: + metric_cls = next(iter(_item.keys())) + metric_cfg = _item.pop(metric_cls) + metric = eval(metric_cls)(**metric_cfg) + metric_dict[metric_cls] = metric + return metric_dict diff --git a/jointContribution/yinglong/ppsci/metric/anomaly_coef.py b/jointContribution/yinglong/ppsci/metric/anomaly_coef.py new file mode 100644 index 0000000000..aa992b6fd5 --- /dev/null +++ b/jointContribution/yinglong/ppsci/metric/anomaly_coef.py @@ -0,0 +1,124 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +import paddle + +from ppsci.metric import base + + +class LatitudeWeightedACC(base.Metric): + r"""Latitude weighted anomaly correlation coefficient. + + $$ + metric = + \dfrac{\sum\limits_{m,n}{L_mX_{mn}Y_{mn}}}{\sqrt{\sum\limits_{m,n}{L_mX_{mn}^{2}}\sum\limits_{m,n}{L_mY_{mn}^{2}}}} + $$ + + $$ + L_m = N_{lat}\dfrac{\cos(lat_m)}{\sum\limits_{j=1}^{N_{lat}}\cos(lat_j)} + $$ + + $lat_m$ is the latitude at m. + $N_{lat}$ is the number of latitude set by `num_lat`. + + Args: + num_lat (Optional[int]): Number of latitude for compute weight, if is None, no weight applied. + mean (Optional[Union[np.array, Tuple[float, ...]]]): Mean of training data. Defaults to None. + keep_batch (bool, optional): Whether keep batch axis. Defaults to False. + variable_dict (Optional[Dict[str, int]]): Variable dictionary, the key is the name of a variable and + the value is its index. Defaults to None. + unlog (bool, optional): whether calculate expm1 for all elements in the array. Defaults to False. + scale (float, optional): The scale value used after expm1. Defaults to 1e-5. + + Examples: + >>> import numpy as np + >>> import ppsci + >>> mean = np.random.randn(20, 720, 1440) + >>> metric = ppsci.metric.LatitudeWeightedACC(720, mean=mean) + """ + + def __init__( + self, + num_lat: Optional[int] = None, + mean: Optional[Union[np.array, Tuple[float, ...]]] = None, + keep_batch: bool = False, + variable_dict: Optional[Dict[str, int]] = None, + unlog: bool = False, + scale: float = 1e-5, + ): + super().__init__(keep_batch) + self.num_lat = num_lat + self.mean = ( + None if mean is None else paddle.to_tensor(mean, paddle.get_default_dtype()) + ) + self.variable_dict = variable_dict + self.unlog = unlog + self.scale = scale + + self.weight = self.get_latitude_weight(num_lat) if num_lat is not None else None + + def get_latitude_weight(self, num_lat: int = 720): + lat_t = paddle.linspace(start=0, stop=1, num=num_lat) + lat_t = paddle.cos(3.1416 * (0.5 - lat_t)) + weight = num_lat * lat_t / paddle.sum(lat_t) + weight = weight.reshape((1, 1, -1, 1)) + return weight + + def scale_expm1(self, x: paddle.Tensor): + return self.scale * paddle.expm1(x) + + @paddle.no_grad() + def forward(self, output_dict, label_dict): + metric_dict = {} + for key in label_dict: + output = ( + self.scale_expm1(output_dict[key]) if self.unlog else output_dict[key] + ) + label = self.scale_expm1(label_dict[key]) if self.unlog else label_dict[key] + + if self.mean is not None: + output = output - self.mean + label = label - self.mean + + if self.weight is not None: + rmse = paddle.sum( + self.weight * output * label, axis=(-1, -2) + ) / paddle.sqrt( + paddle.sum(self.weight * output**2, axis=(-1, -2)) + * paddle.sum(self.weight * label**2, axis=(-1, -2)) + ) + else: + rmse = paddle.sum(output * label, axis=(-1, -2)) / paddle.sqrt( + paddle.sum(output**2, axis=(-1, -2)) + * paddle.sum(label**2, axis=(-1, -2)) + ) + + if self.variable_dict is not None: + for variable_name, idx in self.variable_dict.items(): + if self.keep_batch: + metric_dict[f"{key}.{variable_name}"] = rmse[:, idx] + else: + metric_dict[f"{key}.{variable_name}"] = rmse[:, idx].mean() + else: + if self.keep_batch: + metric_dict[key] = rmse.mean(axis=1) + else: + metric_dict[key] = rmse.mean() + return metric_dict diff --git a/jointContribution/yinglong/ppsci/metric/base.py b/jointContribution/yinglong/ppsci/metric/base.py new file mode 100644 index 0000000000..6a5bb102df --- /dev/null +++ b/jointContribution/yinglong/ppsci/metric/base.py @@ -0,0 +1,23 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle import nn + + +class Metric(nn.Layer): + """Base class for metric.""" + + def __init__(self, keep_batch: bool = False): + super().__init__() + self.keep_batch = keep_batch diff --git a/jointContribution/yinglong/ppsci/metric/mae.py b/jointContribution/yinglong/ppsci/metric/mae.py new file mode 100644 index 0000000000..cd8e28f9ba --- /dev/null +++ b/jointContribution/yinglong/ppsci/metric/mae.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn.functional as F + +from ppsci.metric import base + + +class MAE(base.Metric): + r"""Mean absolute error. + + $$ + metric = \dfrac{1}{N} \Vert \mathbf{x} - \mathbf{y} \Vert_1 + $$ + + $$ + \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} + $$ + + Args: + keep_batch (bool, optional): Whether keep batch axis. Defaults to False. + + Examples: + >>> import ppsci + >>> metric = ppsci.metric.MAE() + """ + + def __init__(self, keep_batch: bool = False): + super().__init__(keep_batch) + + @paddle.no_grad() + def forward(self, output_dict, label_dict): + metric_dict = {} + for key in label_dict: + mae = F.l1_loss(output_dict[key], label_dict[key], "none") + if self.keep_batch: + metric_dict[key] = mae.mean(axis=tuple(range(1, mae.ndim))) + else: + metric_dict[key] = mae.mean() + + return metric_dict diff --git a/jointContribution/yinglong/ppsci/metric/rmse.py b/jointContribution/yinglong/ppsci/metric/rmse.py new file mode 100644 index 0000000000..ec7238c0d8 --- /dev/null +++ b/jointContribution/yinglong/ppsci/metric/rmse.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +import paddle +import paddle.nn.functional as F + +from ppsci.metric import base + + +class LatitudeWeightedRMSE(base.Metric): + r"""Latitude weighted root mean square error. + + $$ + metric =\sqrt{\dfrac{1}{MN}\sum\limits_{m=1}^{M}\sum\limits_{n=1}^{N}L_m(X_{mn}-Y_{mn})^{2}} + $$ + + $$ + L_m = N_{lat}\dfrac{\cos(lat_m)}{\sum\limits_{j=1}^{N_{lat}}\cos(lat_j)} + $$ + + $lat_m$ is the latitude at m. + $N_{lat}$ is the number of latitude set by `num_lat`. + + Args: + num_lat (int): : Number of latitude for compute weight, if is None, no weight applied. + std (Optional[Union[np.array, Tuple[float, ...]]]): Standard Deviation of training dataset. Defaults to None. + keep_batch (bool, optional): Whether keep batch axis. Defaults to False. + variable_dict (Optional[Dict[str, int]]): Variable dictionary, the key is the name of a variable and + the value is its index. Defaults to None. + unlog (bool, optional): whether calculate expm1 for all elements in the array. Defaults to False. + scale (float, optional): The scale value used after expm1. Defaults to 1e-5. + + Examples: + >>> import numpy as np + >>> import ppsci + >>> std = np.random.randn(20, 1, 1) + >>> metric = ppsci.metric.LatitudeWeightedRMSE(720, std=std) + """ + + def __init__( + self, + num_lat: Optional[int] = None, + std: Optional[Union[np.array, Tuple[float, ...]]] = None, + keep_batch: bool = False, + variable_dict: Dict[str, int] = None, + unlog: bool = False, + scale: float = 1e-5, + ): + super().__init__(keep_batch) + self.num_lat = num_lat + self.std = ( + None + if std is None + else paddle.to_tensor(std, paddle.get_default_dtype()).reshape((1, -1)) + ) + self.variable_dict = variable_dict + self.unlog = unlog + self.scale = scale + self.weight = self.get_latitude_weight(num_lat) if num_lat is not None else None + + def get_latitude_weight(self, num_lat: int = 720): + lat_t = paddle.linspace(start=0, stop=1, num=num_lat) + lat_t = paddle.cos(3.1416 * (0.5 - lat_t)) + weight = num_lat * lat_t / paddle.sum(lat_t) + weight = weight.reshape((1, 1, -1, 1)) + return weight + + def scale_expm1(self, x: paddle.Tensor): + return self.scale * paddle.expm1(x) + + @paddle.no_grad() + def forward(self, output_dict, label_dict): + metric_dict = {} + for key in label_dict: + output = ( + self.scale_expm1(output_dict[key]) if self.unlog else output_dict[key] + ) + label = self.scale_expm1(label_dict[key]) if self.unlog else label_dict[key] + + mse = F.mse_loss(output, label, "none") + if self.weight is not None: + rmse = (mse * self.weight).mean(axis=(-1, -2)) ** 0.5 + else: + rmse = mse.mean(axis=(-1, -2)) ** 0.5 + if self.std is not None: + rmse = rmse * self.std + if self.variable_dict is not None: + for variable_name, idx in self.variable_dict.items(): + metric_dict[f"{key}.{variable_name}"] = ( + rmse[:, idx] if self.keep_batch else rmse[:, idx].mean() + ) + else: + metric_dict[key] = rmse.mean(axis=1) if self.keep_batch else rmse.mean() + + return metric_dict diff --git a/jointContribution/yinglong/ppsci/optimizer/__init__.py b/jointContribution/yinglong/ppsci/optimizer/__init__.py new file mode 100644 index 0000000000..54c64f896e --- /dev/null +++ b/jointContribution/yinglong/ppsci/optimizer/__init__.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from ppsci.optimizer import lr_scheduler +from ppsci.optimizer.optimizer import LBFGS +from ppsci.optimizer.optimizer import SGD +from ppsci.optimizer.optimizer import Adam +from ppsci.optimizer.optimizer import AdamW +from ppsci.optimizer.optimizer import Momentum +from ppsci.optimizer.optimizer import OptimizerList +from ppsci.optimizer.optimizer import RMSProp + +__all__ = [ + "LBFGS", + "SGD", + "Adam", + "AdamW", + "Momentum", + "RMSProp", + "OptimizerList", + "lr_scheduler", +] + + +def build_lr_scheduler(cfg, epochs, iters_per_epoch): + """Build learning rate scheduler. + + Args: + cfg (AttrDict): Learing rate scheduler config. + epochs (int): Total epochs. + iters_per_epoch (int): Number of iterations of one epoch. + + Returns: + LRScheduler: Learing rate scheduler. + """ + cfg = copy.deepcopy(cfg) + cfg.update({"epochs": epochs, "iters_per_epoch": iters_per_epoch}) + lr_scheduler_cls = cfg.pop("name") + lr_scheduler_ = eval(lr_scheduler_cls)(**cfg) + return lr_scheduler_() + + +def build_optimizer(cfg, model_list, epochs, iters_per_epoch): + """Build optimizer and learing rate scheduler + + Args: + cfg (AttrDict): Learing rate scheduler config. + model_list (Tuple[nn.Layer, ...]): Tuple of model(s). + epochs (int): Total epochs. + iters_per_epoch (int): Number of iterations of one epoch. + + Returns: + Optimizer, LRScheduler: Optimizer and learing rate scheduler. + """ + # build lr_scheduler + cfg = copy.deepcopy(cfg) + lr_cfg = cfg.pop("lr") + if isinstance(lr_cfg, float): + lr_scheduler = lr_cfg + else: + lr_scheduler = build_lr_scheduler(lr_cfg, epochs, iters_per_epoch) + + # build optimizer + opt_cls = cfg.pop("name") + optimizer = eval(opt_cls)(learning_rate=lr_scheduler, **cfg)(model_list) + + if isinstance(lr_scheduler, float): + return optimizer, None + return optimizer, lr_scheduler diff --git a/jointContribution/yinglong/ppsci/optimizer/lr_scheduler.py b/jointContribution/yinglong/ppsci/optimizer/lr_scheduler.py new file mode 100644 index 0000000000..3ac7be0a26 --- /dev/null +++ b/jointContribution/yinglong/ppsci/optimizer/lr_scheduler.py @@ -0,0 +1,646 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import math +from typing import Tuple +from typing import Union + +from paddle.optimizer import lr + +from ppsci.utils import logger + +__all__ = [ + "Linear", + "Cosine", + "Step", + "Piecewise", + "MultiStepDecay", + "ExponentialDecay", + "CosineWarmRestarts", +] + + +class LRBase: + """Base class for custom learning rates. + + Args: + epochs (int): total epoch(s). + iters_per_epoch (int): number of iterations within an epoch. + learning_rate (float): learning rate. + warmup_epoch (int): number of warmup epochs. + warmup_start_lr (float): start learning rate within warmup. + last_epoch (int): last epoch. + by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter. + verbose (bool): If True, prints a message to stdout for each update. Defaults to False. + """ + + def __init__( + self, + epochs: int, + iters_per_epoch: int, + learning_rate: float, + warmup_epoch: int, + warmup_start_lr: float, + last_epoch: int, + by_epoch: bool, + verbose: bool = False, + ) -> None: + """Initialize and record the necessary parameters""" + super().__init__() + if warmup_epoch >= epochs: + msg = ( + "When using warm up, the value of 'Global.epochs' should be greater " + "than value of 'Optimizer.lr.warmup_epoch'. The value of " + f"'Optimizer.lr.warmup_epoch' has been set to {epochs}." + ) + logger.warning(msg) + warmup_epoch = epochs + self.epochs = epochs + self.iters_per_epoch = iters_per_epoch + self.learning_rate = learning_rate + self.warmup_epoch = warmup_epoch + self.warmup_steps = ( + self.warmup_epoch + if by_epoch + else round(self.warmup_epoch * self.iters_per_epoch) + ) + self.warmup_start_lr = warmup_start_lr + self.last_epoch = last_epoch + self.by_epoch = by_epoch + self.verbose = verbose + + @abc.abstractmethod + def __call__(self, *kargs, **kwargs) -> lr.LRScheduler: + """Generate an learning rate scheduler. + + Returns: + lr.LinearWarmup: learning rate scheduler. + """ + pass + + def linear_warmup( + self, learning_rate: Union[float, lr.LRScheduler] + ) -> lr.LinearWarmup: + """Add an Linear Warmup before learning_rate. + + Args: + learning_rate (Union[float, lr.LRScheduler]): original learning rate without + warmup. + + Returns: + lr.LinearWarmup: learning rate scheduler with warmup. + """ + warmup_lr = lr.LinearWarmup( + learning_rate=learning_rate, + warmup_steps=self.warmup_steps, + start_lr=self.warmup_start_lr, + end_lr=self.learning_rate, + last_epoch=self.last_epoch, + verbose=self.verbose, + ) + return warmup_lr + + +class Constant(lr.LRScheduler): + """Constant learning rate Class implementation. + + Args: + learning_rate (float): The initial learning rate. + last_epoch (int, optional): The index of last epoch. Default: -1. + """ + + def __init__(self, learning_rate: float, last_epoch: int = -1): + self.learning_rate = learning_rate + self.last_epoch = last_epoch + super().__init__() + + def get_lr(self) -> float: + """Always return the same learning rate""" + return self.learning_rate + + +class Linear(LRBase): + """Linear learning rate decay. + + Args: + epochs (int): total epoch(s). + iters_per_epoch (int): number of iterations within an epoch. + learning_rate (float): learning rate. + end_lr (float, optional): The minimum final learning rate. Defaults to 0.0. + power (float, optional): Power of polynomial. Defaults to 1.0. + warmup_epoch (int): number of warmup epochs. + warmup_start_lr (float): start learning rate within warmup. + last_epoch (int): last epoch. + by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter. + + Examples: + >>> import ppsci + >>> lr = ppsci.optimizer.lr_scheduler.Linear(10, 2, 0.001) + """ + + def __init__( + self, + epochs: int, + iters_per_epoch: int, + learning_rate: float, + end_lr: float = 0.0, + power: float = 1.0, + cycle: bool = False, + warmup_epoch: int = 0, + warmup_start_lr: float = 0.0, + last_epoch: int = -1, + by_epoch: bool = False, + ): + super().__init__( + epochs, + iters_per_epoch, + learning_rate, + warmup_epoch, + warmup_start_lr, + last_epoch, + by_epoch, + ) + self.decay_steps = (epochs - self.warmup_epoch) * iters_per_epoch + self.end_lr = end_lr + self.power = power + self.cycle = cycle + self.warmup_steps = round(self.warmup_epoch * iters_per_epoch) + if self.by_epoch: + self.decay_steps = self.epochs - self.warmup_epoch + + def __call__(self): + learning_rate = ( + lr.PolynomialDecay( + learning_rate=self.learning_rate, + decay_steps=self.decay_steps, + end_lr=self.end_lr, + power=self.power, + cycle=self.cycle, + last_epoch=self.last_epoch, + ) + if self.decay_steps > 0 + else Constant(self.learning_rate) + ) + + if self.warmup_steps > 0: + learning_rate = self.linear_warmup(learning_rate) + + setattr(learning_rate, "by_epoch", self.by_epoch) + return learning_rate + + +class ExponentialDecay(LRBase): + """ExponentialDecay learning rate decay. + + Args: + epochs (int): total epoch(s). + iters_per_epoch (int): number of iterations within an epoch. + learning_rate (float): learning rate. + warmup_epoch (int): number of warmup epochs. + warmup_start_lr (float): start learning rate within warmup. + last_epoch (int): last epoch. + by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter. + + Examples: + >>> import ppsci + >>> lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(10, 2, 1e-3, 0.95, 3) + """ + + def __init__( + self, + epochs: int, + iters_per_epoch: int, + learning_rate: float, + gamma: float, + decay_steps: int, + warmup_epoch: int = 0, + warmup_start_lr: float = 0.0, + last_epoch: int = -1, + by_epoch: bool = False, + ): + super().__init__( + epochs, + iters_per_epoch, + learning_rate, + warmup_epoch, + warmup_start_lr, + last_epoch, + by_epoch, + ) + self.decay_steps = decay_steps + self.gamma = gamma + self.warmup_steps = round(self.warmup_epoch * iters_per_epoch) + if self.by_epoch: + self.decay_steps /= iters_per_epoch + + def __call__(self): + learning_rate = lr.ExponentialDecay( + learning_rate=self.learning_rate, + gamma=self.gamma ** (1 / self.decay_steps), + last_epoch=self.last_epoch, + ) + + if self.warmup_steps > 0: + learning_rate = self.linear_warmup(learning_rate) + + setattr(learning_rate, "by_epoch", self.by_epoch) + return learning_rate + + +class Cosine(LRBase): + r"""Cosine learning rate decay. + + lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1) + + Args: + epochs (int): total epoch(s). + iters_per_epoch (int): number of iterations within an epoch. + learning_rate (float): learning rate. + eta_min (float, optional): Minimum learning rate. Defaults to 0.0. + warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0. + warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0. + last_epoch (int, optional): last epoch. Defaults to -1. + by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, + else by iter. Defaults to False. + + Examples: + >>> import ppsci + >>> lr = ppsci.optimizer.lr_scheduler.Cosine(10, 2, 1e-3) + """ + + def __init__( + self, + epochs: int, + iters_per_epoch: int, + learning_rate: float, + eta_min: float = 0.0, + warmup_epoch: int = 0, + warmup_start_lr: float = 0.0, + last_epoch: int = -1, + by_epoch: bool = False, + ): + super().__init__( + epochs, + iters_per_epoch, + learning_rate, + warmup_epoch, + warmup_start_lr, + last_epoch, + by_epoch, + ) + self.T_max = (self.epochs - self.warmup_epoch) * self.iters_per_epoch + self.eta_min = eta_min + if self.by_epoch: + self.T_max = self.epochs - self.warmup_epoch + + def __call__(self): + learning_rate = ( + lr.CosineAnnealingDecay( + learning_rate=self.learning_rate, + T_max=self.T_max, + eta_min=self.eta_min, + last_epoch=self.last_epoch, + ) + if self.T_max > 0 + else Constant(self.learning_rate) + ) + + if self.warmup_steps > 0: + learning_rate = self.linear_warmup(learning_rate) + + setattr(learning_rate, "by_epoch", self.by_epoch) + return learning_rate + + +class Step(LRBase): + """Step learning rate decay. + + Args: + epochs (int): total epoch(s). + iters_per_epoch (int): number of iterations within an epoch. + learning_rate (float): learning rate. + step_size (int): the interval to update. + gamma (float, optional): The Ratio that the learning rate will be reduced. + ``new_lr = origin_lr * gamma``. It should be less than 1.0. Default: 0.1. + warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0. + warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0. + last_epoch (int, optional): last epoch. Defaults to -1. + by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, + else by iter. Defaults to False. + + Examples: + >>> import ppsci + >>> lr = ppsci.optimizer.lr_scheduler.Step(10, 1, 1e-3, 2, 0.95) + """ + + def __init__( + self, + epochs: int, + iters_per_epoch: int, + learning_rate: float, + step_size: int, + gamma: float, + warmup_epoch: int = 0, + warmup_start_lr: float = 0.0, + last_epoch: int = -1, + by_epoch: bool = False, + ): + super().__init__( + epochs, + iters_per_epoch, + learning_rate, + warmup_epoch, + warmup_start_lr, + last_epoch, + by_epoch, + ) + self.step_size = step_size * iters_per_epoch + self.gamma = gamma + if self.by_epoch: + self.step_size = step_size + + def __call__(self): + learning_rate = lr.StepDecay( + learning_rate=self.learning_rate, + step_size=self.step_size, + gamma=self.gamma, + last_epoch=self.last_epoch, + ) + + if self.warmup_steps > 0: + learning_rate = self.linear_warmup(learning_rate) + + setattr(learning_rate, "by_epoch", self.by_epoch) + return learning_rate + + +class Piecewise(LRBase): + """Piecewise learning rate decay + + Args: + epochs (int): total epoch(s) + iters_per_epoch (int): number of iterations within an epoch + decay_epochs (Tuple[int, ...]): A list of steps numbers. The type of element in the + list is python int. + values (Tuple[float, ...]): Tuple of learning rate values that will be picked during + different epoch boundaries. + warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0. + warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0. + last_epoch (int, optional): last epoch. Defaults to -1. + by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, + else by iter. Defaults to False. + + Examples: + >>> import ppsci + >>> lr = ppsci.optimizer.lr_scheduler.Piecewise(10, 1, [2, 4], (1e-3, 1e-4)) + """ + + def __init__( + self, + epochs: int, + iters_per_epoch: int, + decay_epochs: Tuple[int, ...], + values: Tuple[float, ...], + warmup_epoch: int = 0, + warmup_start_lr: float = 0.0, + last_epoch: int = -1, + by_epoch: bool = False, + ): + super().__init__( + epochs, + iters_per_epoch, + values[0], + warmup_epoch, + warmup_start_lr, + last_epoch, + by_epoch, + ) + self.values = values + self.boundaries_steps = [e * iters_per_epoch for e in decay_epochs] + if self.by_epoch is True: + self.boundaries_steps = decay_epochs + + def __call__(self): + learning_rate = lr.PiecewiseDecay( + boundaries=self.boundaries_steps, + values=self.values, + last_epoch=self.last_epoch, + ) + + if self.warmup_steps > 0: + learning_rate = self.linear_warmup(learning_rate) + + setattr(learning_rate, "by_epoch", self.by_epoch) + return learning_rate + + +class MultiStepDecay(LRBase): + """MultiStepDecay learning rate decay + + Args: + epochs (int): total epoch(s) + iters_per_epoch (int): number of iterations within an epoch + learning_rate (float): learning rate + milestones (Tuple[int, ...]): Tuple of each boundaries. should be increasing. + gamma (float, optional): The Ratio that the learning rate will be reduced. + `new_lr = origin_lr * gamma`. It should be less than 1.0. Defaults to 0.1. + warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0. + warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0. + last_epoch (int, optional): last epoch. Defaults to -1. + by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, + else by iter. Defaults to False. + + Examples: + >>> import ppsci + >>> lr = ppsci.optimizer.lr_scheduler.MultiStepDecay(10, 1, 1e-3, (4, 5)) + """ + + def __init__( + self, + epochs: int, + iters_per_epoch: int, + learning_rate: float, + milestones: Tuple[int, ...], + gamma: float = 0.1, + warmup_epoch: int = 0, + warmup_start_lr: float = 0.0, + last_epoch: int = -1, + by_epoch: bool = False, + ): + super().__init__( + epochs, + iters_per_epoch, + learning_rate, + warmup_epoch, + warmup_start_lr, + last_epoch, + by_epoch, + ) + self.milestones = [x * iters_per_epoch for x in milestones] + self.gamma = gamma + if self.by_epoch: + self.milestones = milestones + + def __call__(self): + learning_rate = lr.MultiStepDecay( + learning_rate=self.learning_rate, + milestones=self.milestones, + gamma=self.gamma, + last_epoch=self.last_epoch, + ) + + if self.warmup_steps > 0: + learning_rate = self.linear_warmup(learning_rate) + + setattr(learning_rate, "by_epoch", self.by_epoch) + return learning_rate + + +class CosineAnnealingWarmRestarts(lr.LRScheduler): + """The implementation of cosine annealing schedule with warm restarts. + + Args: + learning_rate (float): Learning rate + T_0 (int): Number of iterations for the first restart. + T_mult (int, optional): A factor increases T_i after a restart. Defaults to 1. + eta_min (float, optional): Minimum learning rate. Defaults to 0. + last_epoch (int, optional): The index of last epoch. Defaults to -1. + verbose (bool, optional): If `True`, prints a message to stdout for each update. Defaults to False. + """ + + def __init__( + self, + learning_rate: float, + T_0: int, + T_mult: int = 1, + eta_min: float = 0.0, + last_epoch: int = -1, + verbose: bool = False, + ): + if T_0 <= 0 or not isinstance(T_0, int): + raise ValueError(f"Expected positive integer T_0, but got {T_0}") + if T_mult < 1 or not isinstance(T_mult, int): + raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}") + self.T_0 = T_0 + self.T_i = T_0 + self.T_mult = T_mult + self.eta_min = eta_min + self.T_cur = last_epoch + super().__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + return ( + self.eta_min + + (self.base_lr - self.eta_min) + * (1 + math.cos(math.pi * self.T_cur / self.T_i)) + / 2 + ) + + def step(self, epoch=None): + if epoch is None and self.last_epoch < 0: + epoch = 0 + + if epoch is None: + epoch = self.last_epoch + 1 + self.T_cur = self.T_cur + 1 + if self.T_cur >= self.T_i: + self.T_cur = self.T_cur - self.T_i + self.T_i = self.T_i * self.T_mult + else: + if epoch < 0: + raise ValueError(f"Expected non-negative epoch, but got {epoch}") + if epoch >= self.T_0: + if self.T_mult == 1: + self.T_cur = epoch % self.T_0 + else: + n = int( + math.log( + (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult + ) + ) + self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / ( + self.T_mult - 1 + ) + self.T_i = self.T_0 * self.T_mult ** (n) + else: + self.T_i = self.T_0 + self.T_cur = epoch + self.last_epoch = math.floor(epoch) + self.last_lr = self.get_lr() + + +class CosineWarmRestarts(LRBase): + """Set the learning rate using a cosine annealing schedule with warm restarts. + + Args: + epochs (int): Total epoch(s) + iters_per_epoch (int): Number of iterations within an epoch + learning_rate (float): Learning rate + T_0 (int): Number of iterations for the first restart. + T_mult (int): A factor increases T_i after a restart + eta_min (float, optional): Minimum learning rate. Defaults to 0.0. + warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0. + warmup_start_lr (float, optional): Start learning rate within warmup. Defaults to 0.0. + last_epoch (int, optional): Last epoch. Defaults to -1. + by_epoch (bool, optional): Learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. + + Examples: + >>> import ppsci + >>> lr = ppsci.optimizer.lr_scheduler.CosineWarmRestarts(20, 1, 1e-3, 14, 2) + """ + + def __init__( + self, + epochs: int, + iters_per_epoch: int, + learning_rate: float, + T_0: int, + T_mult: int, + eta_min: float = 0.0, + warmup_epoch: int = 0, + warmup_start_lr: float = 0.0, + last_epoch: int = -1, + by_epoch: bool = False, + ): + super().__init__( + epochs, + iters_per_epoch, + learning_rate, + warmup_epoch, + warmup_start_lr, + last_epoch, + by_epoch, + ) + self.T_0 = T_0 + self.T_mult = T_mult + self.eta_min = eta_min + if self.by_epoch is False: + self.T_0 = T_0 * iters_per_epoch + + def __call__(self): + learning_rate = CosineAnnealingWarmRestarts( + learning_rate=self.learning_rate, + T_0=self.T_0, + T_mult=self.T_mult, + eta_min=self.eta_min, + last_epoch=self.last_epoch, + verbose=self.verbose, + ) + + if self.warmup_steps > 0: + learning_rate = self.linear_warmup(learning_rate) + + setattr(learning_rate, "by_epoch", self.by_epoch) + return learning_rate diff --git a/jointContribution/yinglong/ppsci/optimizer/optimizer.py b/jointContribution/yinglong/ppsci/optimizer/optimizer.py new file mode 100644 index 0000000000..54679d265f --- /dev/null +++ b/jointContribution/yinglong/ppsci/optimizer/optimizer.py @@ -0,0 +1,513 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from paddle import nn +from paddle import optimizer as optim +from paddle import regularizer +from paddle.incubate import optimizer as incubate_optim +from typing_extensions import Literal + +from ppsci.utils import logger +from ppsci.utils import misc + +if TYPE_CHECKING: + import paddle + +__all__ = ["SGD", "Momentum", "Adam", "RMSProp", "AdamW", "LBFGS", "OptimizerList"] + + +class SGD: + """Stochastic Gradient Descent. + + Args: + learning_rate (Union[float, optim.lr.LRScheduler], optional): The learning rate + used to update parameter(s). Defaults to 0.001. + weight_decay (Optional[Union[float, regularizer.L1Decay, regularizer.L2Decay]]): + Regularization strategy. Defaults to None. + grad_clip (Optional[Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]]): + Gradient cliping strategy. Defaults to None. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20) + >>> opt = ppsci.optimizer.SGD(1e-3)((model,)) + """ + + def __init__( + self, + learning_rate: Union[float, optim.lr.LRScheduler] = 0.001, + weight_decay: Optional[ + Union[float, regularizer.L1Decay, regularizer.L2Decay] + ] = None, + grad_clip: Optional[ + Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm] + ] = None, + ): + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.grad_clip = grad_clip + + def __call__(self, model_list: Tuple[nn.Layer, ...]): + # model_list is None in static graph + parameters = ( + sum([m.parameters() for m in model_list], []) if model_list else None + ) + opt = optim.SGD( + learning_rate=self.learning_rate, + parameters=parameters, + weight_decay=self.weight_decay, + grad_clip=self.grad_clip, + ) + return opt + + +class Momentum: + """Simple Momentum optimizer with velocity state. + + Args: + learning_rate (Union[float, optim.lr.LRScheduler]): The learning rate + used to update parameter(s). + momentum (float): Momentum factor. + weight_decay (Optional[Union[float, regularizer.L1Decay, regularizer.L2Decay]]): + Regularization strategy. Defaults to None. + grad_clip (Optional[Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]]): + Gradient cliping strategy. Defaults to None. + use_nesterov (bool, optional): Whether to use nesterov momentum. Defaults to False. + no_weight_decay_name (Optional[str]): List of names of no weight decay parameters split by white space. Defaults to None. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20) + >>> opt = ppsci.optimizer.Momentum(1e-3, 0.9)((model,)) + """ + + def __init__( + self, + learning_rate: Union[float, optim.lr.LRScheduler], + momentum: float, + weight_decay: Optional[ + Union[float, regularizer.L1Decay, regularizer.L2Decay] + ] = None, + grad_clip: Optional[ + Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm] + ] = None, + use_nesterov: bool = False, + no_weight_decay_name: Optional[str] = None, + ): + super().__init__() + self.learning_rate = learning_rate + self.momentum = momentum + self.weight_decay = weight_decay + self.grad_clip = grad_clip + self.use_nesterov = use_nesterov + self.no_weight_decay_name_list = ( + no_weight_decay_name.split() if no_weight_decay_name else [] + ) + + def __call__(self, model_list: Tuple[nn.Layer, ...]): + # model_list is None in static graph + parameters = None + if len(self.no_weight_decay_name_list) > 0: + params_with_decay = [] + params_without_decay = [] + for m in model_list: + params = [ + p + for n, p in m.named_parameters() + if not any(nd in n for nd in self.no_weight_decay_name_list) + ] + params_with_decay.extend(params) + params = [ + p + for n, p in m.named_parameters() + if any(nd in n for nd in self.no_weight_decay_name_list) + ] + params_without_decay.extend(params) + parameters = [ + {"params": params_with_decay, "weight_decay": self.weight_decay}, + {"params": params_without_decay, "weight_decay": 0.0}, + ] + else: + parameters = ( + sum([m.parameters() for m in model_list], []) if model_list else None + ) + opt = optim.Momentum( + learning_rate=self.learning_rate, + momentum=self.momentum, + weight_decay=self.weight_decay, + grad_clip=self.grad_clip, + use_nesterov=self.use_nesterov, + parameters=parameters, + ) + if hasattr(opt, "_use_multi_tensor"): + opt = optim.Momentum( + learning_rate=self.learning_rate, + momentum=self.momentum, + weight_decay=self.weight_decay, + grad_clip=self.grad_clip, + parameters=parameters, + use_nesterov=self.use_nesterov, + use_multi_tensor=True, + ) + return opt + + +class Adam: + """Adam: A Method for Stochastic Optimization. + + Args: + learning_rate (Union[float, optim.lr.LRScheduler], optional): The learning rate + used to update parameter(s). Defaults to 0.001. + beta1 (float, optional): The exponential decay rate for the 1st moment estimates. Defaults to 0.9. + beta2 (float, optional): The exponential decay rate for the 2nd moment estimates. Defaults to 0.999. + epsilon (float, optional): A small float value for numerical stability. Defaults to 1e-08. + weight_decay (Optional[Union[float, regularizer.L1Decay, regularizer.L2Decay]]): Regularization strategy. Defaults to None. + grad_clip (Optional[Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]]): Gradient cliping strategy. Defaults to None. + lazy_mode (bool, optional): Whether to enable lazy mode for moving-average. Defaults to False. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20) + >>> opt = ppsci.optimizer.Adam(1e-3)((model,)) + """ + + def __init__( + self, + learning_rate: Union[float, optim.lr.LRScheduler] = 0.001, + beta1: float = 0.9, + beta2: float = 0.999, + epsilon: float = 1e-08, + weight_decay: Optional[ + Union[float, regularizer.L1Decay, regularizer.L2Decay] + ] = None, + grad_clip: Optional[ + Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm] + ] = None, + lazy_mode: bool = False, + ): + self.learning_rate = learning_rate + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.grad_clip = grad_clip + self.lazy_mode = lazy_mode + + def __call__(self, model_list: Tuple[nn.Layer, ...]): + # model_list is None in static graph + parameters = ( + sum([m.parameters() for m in model_list], []) if model_list else None + ) + opt = optim.Adam( + learning_rate=self.learning_rate, + beta1=self.beta1, + beta2=self.beta2, + epsilon=self.epsilon, + weight_decay=self.weight_decay, + grad_clip=self.grad_clip, + lazy_mode=self.lazy_mode, + parameters=parameters, + ) + return opt + + +class LBFGS: + """The L-BFGS is a quasi-Newton method for solving an unconstrained optimization + problem over a differentiable function. Closely related is the Newton method for minimization. + + Args: + learning_rate (float, optional): The learning rate + used to update parameter(s). Defaults to 1.0. + max_iter (int, optional): Maximal number of iterations per optimization step. + Defaults to 1. + max_eval (Optional[int]): Maximal number of function evaluations per + optimization step. Defaults to None. + tolerance_grad (float, optional): Termination tolerance on first order optimality. + Defaults to 1e-07. + tolerance_change (float, optional): termination tolerance on function + value/parameterchanges. Defaults to 1e-09. + history_size (int, optional): Update history size. Defaults to 100. + line_search_fn (Optional[Literal["strong_wolfe"]]): Either 'strong_wolfe' or None. + Defaults to "strong_wolfe". + + Examples: + >>> import ppsci + >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20) + >>> opt = ppsci.optimizer.LBFGS(1e-3)((model,)) + """ + + def __init__( + self, + learning_rate: float = 1.0, + max_iter: int = 1, + max_eval: Optional[int] = None, + tolerance_grad: float = 1e-07, + tolerance_change: float = 1e-09, + history_size: int = 100, + line_search_fn: Optional[Literal["strong_wolfe"]] = "strong_wolfe", + ): + self.lr = learning_rate + self.max_iter = max_iter + self.max_eval = max_eval + self.tolerance_grad = tolerance_grad + self.tolerance_change = tolerance_change + self.history_size = history_size + self.line_search_fn = line_search_fn + + def __call__(self, model_list: Tuple[nn.Layer, ...]): + # model_list is None in static graph + parameters = ( + sum([m.parameters() for m in model_list], []) if model_list else None + ) + try: + opt = getattr(optim, "LBFGS")( + learning_rate=self.lr, + max_iter=self.max_iter, + max_eval=self.max_eval, + tolerance_grad=self.tolerance_grad, + tolerance_change=self.tolerance_change, + history_size=self.history_size, + line_search_fn=self.line_search_fn, + parameters=parameters, + ) + except AttributeError: + opt = getattr(incubate_optim, "LBFGS")( + learning_rate=self.lr, + max_iter=self.max_iter, + max_eval=self.max_eval, + tolerance_grad=self.tolerance_grad, + tolerance_change=self.tolerance_change, + history_size=self.history_size, + line_search_fn=self.line_search_fn, + parameters=parameters, + ) + return opt + + +class RMSProp: + """Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method. + + Args: + learning_rate (Union[float, optim.lr.LRScheduler]): The learning rate + used to update parameter(s) + rho (float, optional): Factor ρ in equation. Defaults to 0.95. + epsilon (float, optional): Factor ϵ in equation as a smoothing term. Defaults to 1e-6. + momentum (float, optional):β in equation is the momentum term. Defaults to 0.0. + weight_decay (Optional[Union[float, regularizer.L1Decay, regularizer.L2Decay]]): + Regularization strategy. Defaults to None. + grad_clip (Optional[Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]]): + Gradient cliping strategy. Defaults to None. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20) + >>> opt = ppsci.optimizer.RMSProp(1e-3)((model,)) + """ + + def __init__( + self, + learning_rate: Union[float, optim.lr.LRScheduler], + rho: float = 0.95, + epsilon: float = 1e-6, + momentum: float = 0.0, + weight_decay: Optional[ + Union[float, regularizer.L1Decay, regularizer.L2Decay] + ] = None, + grad_clip: Optional[ + Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm] + ] = None, + ): + super().__init__() + self.learning_rate = learning_rate + self.momentum = momentum + self.rho = rho + self.epsilon = epsilon + self.weight_decay = weight_decay + self.grad_clip = grad_clip + + def __call__(self, model_list: Tuple[nn.Layer, ...]): + # model_list is None in static graph + parameters = ( + sum([m.parameters() for m in model_list], []) if model_list else None + ) + opt = optim.RMSProp( + learning_rate=self.learning_rate, + momentum=self.momentum, + rho=self.rho, + epsilon=self.epsilon, + weight_decay=self.weight_decay, + grad_clip=self.grad_clip, + parameters=parameters, + ) + return opt + + +class AdamW: + """AdamW is implemented based on DECOUPLED WEIGHT DECAY REGULARIZATION. + + Args: + learning_rate (Union[float, optim.lr.LRScheduler], optional): The learning rate + used to update parameter(s). Defaults to 0.001. + beta1 (float, optional): The exponential decay rate for the 1st moment estimates. Defaults to 0.9. + beta2 (float, optional): The exponential decay rate for the 2nd moment estimates. Defaults to 0.999. + epsilon (float, optional): A small float value for numerical stability. Defaults to 1e-8. + weight_decay (float, optional): Regularization cofficient. Defaults to 0.01. + grad_clip (Optional[Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm]]): Gradient cliping strategy. Defaults to None. + no_weight_decay_name (Optional[str]): List of names of no weight decay parameters split by white space. Defaults to None. + one_dim_param_no_weight_decay (bool, optional): Apply no weight decay on 1-D parameter(s). Defaults to False. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20) + >>> opt = ppsci.optimizer.AdamW(1e-3)((model,)) + """ + + def __init__( + self, + learning_rate: Union[float, optim.lr.LRScheduler] = 0.001, + beta1: float = 0.9, + beta2: float = 0.999, + epsilon: float = 1e-8, + weight_decay: float = 0.001, + grad_clip: Optional[ + Union[nn.ClipGradByNorm, nn.ClipGradByValue, nn.ClipGradByGlobalNorm] + ] = None, + no_weight_decay_name: Optional[str] = None, + one_dim_param_no_weight_decay: bool = False, + ): + super().__init__() + self.learning_rate = learning_rate + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.grad_clip = grad_clip + self.weight_decay = weight_decay + self.no_weight_decay_name_list = ( + no_weight_decay_name.split() if no_weight_decay_name else [] + ) + self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay + + def __call__(self, model_list: Tuple[nn.Layer, ...]): + # model_list is None in static graph + parameters = ( + sum([m.parameters() for m in model_list], []) if model_list else None + ) + + # TODO(gaotingquan): model_list is None when in static graph, "no_weight_decay" not work. + if model_list is None: + if ( + self.one_dim_param_no_weight_decay + or len(self.no_weight_decay_name_list) != 0 + ): + msg = '"AdamW" does not support setting "no_weight_decay" in static graph. Please use dynamic graph.' + logger.error(Exception(msg)) + raise Exception(msg) + + self.no_weight_decay_param_name_list = ( + [ + p.name + for model in model_list + for n, p in model.named_parameters() + if any(nd in n for nd in self.no_weight_decay_name_list) + ] + if model_list + else [] + ) + + if self.one_dim_param_no_weight_decay: + self.no_weight_decay_param_name_list += ( + [ + p.name + for model in model_list + for n, p in model.named_parameters() + if len(p.shape) == 1 + ] + if model_list + else [] + ) + + opt = optim.AdamW( + learning_rate=self.learning_rate, + beta1=self.beta1, + beta2=self.beta2, + epsilon=self.epsilon, + parameters=parameters, + weight_decay=self.weight_decay, + grad_clip=self.grad_clip, + apply_decay_param_fun=self._apply_decay_param_fun, + ) + return opt + + def _apply_decay_param_fun(self, name): + return name not in self.no_weight_decay_param_name_list + + +class OptimizerList: + """OptimizerList which wrap more than one optimizer. + NOTE: LBFGS is not supported yet. + + Args: + optimizer_list (Tuple[optim.Optimizer, ...]): Optimizers listed in a tuple. + + Examples: + >>> import ppsci + >>> model1 = ppsci.arch.MLP(("x",), ("u",), 5, 20) + >>> opt1 = ppsci.optimizer.Adam(1e-3)((model1,)) + >>> model2 = ppsci.arch.MLP(("y",), ("v",), 5, 20) + >>> opt2 = ppsci.optimizer.Adam(1e-3)((model2,)) + >>> opt = ppsci.optimizer.OptimizerList((opt1, opt2)) + """ + + def __init__(self, optimizer_list: Tuple[optim.Optimizer, ...]): + super().__init__() + self._opt_list = optimizer_list + if "LBFGS" in set(misc.typename(opt) for opt in optimizer_list): + raise ValueError("LBFGS is not supported in OptimizerList yet.") + + def step(self): + for opt in self._opt_list: + opt.step() + + def clear_grad(self): + for opt in self._opt_list: + opt.clear_grad() + + def get_lr(self) -> float: + """Return learning rate of first optimizer""" + return self._opt_list[0].get_lr() + + def set_state_dict(self, state_dicts: List[Dict[str, "paddle.Tensor"]]): + for i, opt in enumerate(self._opt_list): + opt.set_state_dict(state_dicts[i]) + + def state_dict(self) -> List[Dict[str, "paddle.Tensor"]]: + state_dicts = [opt.state_dict() for opt in self._opt_list] + return state_dicts + + def __len__(self) -> int: + return len(self._opt_list) + + def __getitem__(self, idx): + return self._opt_list[idx] + + def __setitem__(self, idx, opt): + raise NotImplementedError("Can not modify any item in OptimizerList.") diff --git a/jointContribution/yinglong/ppsci/solver/__init__.py b/jointContribution/yinglong/ppsci/solver/__init__.py new file mode 100644 index 0000000000..03f97bc2d9 --- /dev/null +++ b/jointContribution/yinglong/ppsci/solver/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ppsci.solver import eval +from ppsci.solver import train +from ppsci.solver import visu +from ppsci.solver.solver import Solver + +__all__ = [ + "eval", + "train", + "visu", + "Solver", +] diff --git a/jointContribution/yinglong/ppsci/solver/eval.py b/jointContribution/yinglong/ppsci/solver/eval.py new file mode 100644 index 0000000000..a6310374c7 --- /dev/null +++ b/jointContribution/yinglong/ppsci/solver/eval.py @@ -0,0 +1,270 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import TYPE_CHECKING + +import paddle +from paddle import io + +from ppsci.solver import printer +from ppsci.utils import misc +from ppsci.utils import profiler + +if TYPE_CHECKING: + from ppsci import solver + + +def _eval_by_dataset(solver: "solver.Solver", epoch_id: int, log_freq: int) -> float: + """Evaluate with computing metric on total samples. + + Args: + solver (solver.Solver): Main Solver. + epoch_id (int): Epoch id. + log_freq (int): Log evaluation information every `log_freq` steps. + + Returns: + float: Target metric computed during evaluation. + """ + target_metric: float = None + for _, _validator in solver.validator.items(): + all_input = misc.Prettydefaultdict(list) + all_output = misc.Prettydefaultdict(list) + all_label = misc.Prettydefaultdict(list) + if isinstance(_validator.data_loader, io.DataLoader): + num_samples = len(_validator.data_loader.dataset) + else: + num_samples = _validator.data_loader.num_samples + + loss_dict = misc.Prettydefaultdict(float) + reader_tic = time.perf_counter() + batch_tic = time.perf_counter() + for iter_id, batch in enumerate(_validator.data_loader, start=1): + input_dict, label_dict, weight_dict = batch + # profile code + # profiler.add_profiler_step(solver.cfg["profiler_options"]) + if iter_id == 5: + # 5 step for warmup + for key in solver.eval_time_info: + solver.eval_time_info[key].reset() + reader_cost = time.perf_counter() - reader_tic + for v in input_dict.values(): + v.stop_gradient = False + + # forward + with solver.autocast_context_manager( + solver.use_amp, solver.amp_level + ), solver.no_grad_context_manager(solver.eval_with_no_grad): + output_dict, validator_loss = solver.forward_helper.eval_forward( + _validator.output_expr, + input_dict, + solver.model, + _validator, + label_dict, + weight_dict, + ) + + loss_dict[f"loss({_validator.name})"] = float(validator_loss) + + # collect batch data + for key, input in input_dict.items(): + all_input[key].append( + input.detach() + if solver.world_size == 1 + else misc.all_gather(input.detach()) + ) + for key, output in output_dict.items(): + all_output[key].append( + output.detach() + if solver.world_size == 1 + else misc.all_gather(output.detach()) + ) + for key, label in label_dict.items(): + all_label[key].append( + label.detach() + if solver.world_size == 1 + else misc.all_gather(label.detach()) + ) + + batch_cost = time.perf_counter() - batch_tic + solver.eval_time_info["reader_cost"].update(reader_cost) + solver.eval_time_info["batch_cost"].update(batch_cost) + batch_size = next(iter(input_dict.values())).shape[0] + printer.update_eval_loss(solver, loss_dict, batch_size) + if iter_id == 1 or iter_id % log_freq == 0: + printer.log_eval_info( + solver, + batch_size, + epoch_id, + len(_validator.data_loader), + iter_id, + ) + + reader_tic = time.perf_counter() + batch_tic = time.perf_counter() + + # concate all data and discard padded sample(s) + for key in all_input: + all_input[key] = paddle.concat(all_input[key]) + if len(all_input[key]) > num_samples: + all_input[key] = all_input[key][:num_samples] + for key in all_output: + all_output[key] = paddle.concat(all_output[key]) + if len(all_output[key]) > num_samples: + all_output[key] = all_output[key][:num_samples] + for key in all_label: + all_label[key] = paddle.concat(all_label[key]) + if len(all_label[key]) > num_samples: + all_label[key] = all_label[key][:num_samples] + + metric = misc.PrettyOrderedDict() + for metric_name, metric_func in _validator.metric.items(): + metric_dict = metric_func(all_output, all_label) + metric[metric_name] = metric_dict + for var_name, metric_value in metric_dict.items(): + metric_str = f"{metric_name}.{var_name}({_validator.name})" + if metric_str not in solver.eval_output_info: + solver.eval_output_info[metric_str] = misc.AverageMeter( + metric_str, ".5f" + ) + solver.eval_output_info[metric_str].update( + float(metric_value), num_samples + ) + + # use the first metric for return value + if target_metric is None: + tmp = metric + while isinstance(tmp, dict): + tmp = next(iter(tmp.values())) + target_metric = float(tmp) + + return target_metric + + +def _eval_by_batch(solver: "solver.Solver", epoch_id: int, log_freq: int) -> float: + """Evaluate with computing metric by batch, which is memory-efficient. + + Args: + solver (solver.Solver): Main Solver. + epoch_id (int): Epoch id. + log_freq (int): Log evaluation information every `log_freq` steps. + + Returns: + float: Target metric computed during evaluation. + """ + target_metric: float = None + for _, _validator in solver.validator.items(): + if isinstance(_validator.data_loader, io.DataLoader): + num_samples = len(_validator.data_loader.dataset) + else: + num_samples = _validator.data_loader.num_samples + + loss_dict = misc.Prettydefaultdict(float) + metric = misc.PrettyOrderedDict() + reader_tic = time.perf_counter() + batch_tic = time.perf_counter() + for iter_id, batch in enumerate(_validator.data_loader, start=1): + input_dict, label_dict, weight_dict,input_time = batch + # profile code + # profiler.add_profiler_step(solver.cfg["profiler_options"]) + if iter_id == 5: + # 5 step for warmup + for key in solver.eval_time_info: + solver.eval_time_info[key].reset() + reader_cost = time.perf_counter() - reader_tic + batch_size = next(iter(input_dict.values())).shape[0] + for v in input_dict.values(): + v.stop_gradient = False + + # forward + with solver.autocast_context_manager( + solver.use_amp, solver.amp_level + ), solver.no_grad_context_manager(solver.eval_with_no_grad): + output_dict, validator_loss = solver.forward_helper.eval_forward( + _validator.output_expr, + input_dict, + solver.model, + _validator, + label_dict, + weight_dict, + input_time, + ) + + loss_dict[f"loss({_validator.name})"] = float(validator_loss) + + # collect batch metric + for metric_name, metric_func in _validator.metric.items(): + metric_dict = metric_func(output_dict, label_dict) + if metric_name not in metric: + metric[metric_name] = misc.Prettydefaultdict(list) + for var_name, metric_value in metric_dict.items(): + metric[metric_name][var_name].append( + metric_value + if solver.world_size == 1 + else misc.all_gather(metric_value) + ) + + batch_cost = time.perf_counter() - batch_tic + solver.eval_time_info["reader_cost"].update(reader_cost) + solver.eval_time_info["batch_cost"].update(batch_cost) + printer.update_eval_loss(solver, loss_dict, batch_size) + if iter_id == 1 or iter_id % log_freq == 0: + printer.log_eval_info( + solver, + batch_size, + epoch_id, + len(_validator.data_loader), + iter_id, + ) + + reader_tic = time.perf_counter() + batch_tic = time.perf_counter() + + # concate all metric and discard metric of padded sample(s) + for metric_name, metric_dict in metric.items(): + for var_name, metric_value in metric_dict.items(): + metric_value = paddle.concat(metric_value)[:num_samples] + metric_value = float(metric_value.mean()) + metric[metric_name][var_name] = metric_value + metric_str = f"{metric_name}.{var_name}({_validator.name})" + if metric_str not in solver.eval_output_info: + solver.eval_output_info[metric_str] = misc.AverageMeter( + metric_str, ".5f" + ) + solver.eval_output_info[metric_str].update(metric_value, num_samples) + + # use the first metric for return value + if target_metric is None: + tmp = metric + while isinstance(tmp, dict): + tmp = next(iter(tmp.values())) + target_metric = tmp + + return target_metric + + +def eval_func(solver: "solver.Solver", epoch_id: int, log_freq: int) -> float: + """Evaluation function. + + Args: + solver (solver.Solver): Main Solver. + epoch_id (int): Epoch id. + log_freq (int): Log evaluation information every `log_freq` steps. + + Returns: + float: Target metric computed during evaluation. + """ + if solver.compute_metric_by_batch: + return _eval_by_batch(solver, epoch_id, log_freq) + return _eval_by_dataset(solver, epoch_id, log_freq) diff --git a/jointContribution/yinglong/ppsci/solver/printer.py b/jointContribution/yinglong/ppsci/solver/printer.py new file mode 100644 index 0000000000..a640b73b59 --- /dev/null +++ b/jointContribution/yinglong/ppsci/solver/printer.py @@ -0,0 +1,114 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime + +from ppsci.utils import logger +from ppsci.utils import misc + + +def update_train_loss(trainer, loss_dict, batch_size): + # update_output_info + for key in loss_dict: + if key not in trainer.train_output_info: + trainer.train_output_info[key] = misc.AverageMeter(key, "7.5f") + trainer.train_output_info[key].update(float(loss_dict[key]), batch_size) + + +def update_eval_loss(trainer, loss_dict, batch_size): + # update_output_info + for key in loss_dict: + if key not in trainer.eval_output_info: + trainer.eval_output_info[key] = misc.AverageMeter(key, "7.5f") + trainer.eval_output_info[key].update(float(loss_dict[key]), batch_size) + + +def log_train_info(trainer, batch_size, epoch_id, iter_id): + lr_msg = f"lr: {trainer.optimizer.get_lr():.8f}" + + metric_msg = ", ".join( + [ + f"{key}: {trainer.train_output_info[key].avg:.5f}" + for key in trainer.train_output_info + ] + ) + + time_msg = ", ".join( + [trainer.train_time_info[key].mean for key in trainer.train_time_info] + ) + + ips_msg = ( + f"ips: {batch_size / trainer.train_time_info['batch_cost'].avg:.5f} samples/s" + ) + + eta_sec = ( + (trainer.epochs - epoch_id + 1) * trainer.iters_per_epoch - iter_id + ) * trainer.train_time_info["batch_cost"].avg + eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec))):s}" + logger.info( + f"[Train][Epoch {epoch_id}/{trainer.epochs}]" + f"[Iter: {iter_id}/{trainer.iters_per_epoch}] {lr_msg}, " + f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}" + ) + + logger.scaler( + name="lr", + value=trainer.optimizer.get_lr(), + step=trainer.global_step, + writer=trainer.vdl_writer, + ) + + for key in trainer.train_output_info: + logger.scaler( + name=f"train_{key}", + value=trainer.train_output_info[key].avg, + step=trainer.global_step, + writer=trainer.vdl_writer, + ) + + +def log_eval_info(trainer, batch_size, epoch_id, iters_per_epoch, iter_id): + metric_msg = ", ".join( + [ + f"{key}: {trainer.eval_output_info[key].avg:.5f}" + for key in trainer.eval_output_info + ] + ) + + time_msg = ", ".join( + [trainer.eval_time_info[key].mean for key in trainer.eval_time_info] + ) + + ips_msg = ( + f"ips: {batch_size / trainer.eval_time_info['batch_cost'].avg:.5f}" f"samples/s" + ) + + eta_sec = (iters_per_epoch - iter_id) * trainer.eval_time_info["batch_cost"].avg + eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec))):s}" + logger.info( + f"[Eval][Epoch {epoch_id}][Iter: {iter_id}/{iters_per_epoch}] " + f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}" + ) + + for key in trainer.eval_output_info: + logger.scaler( + name=f"eval_{key}", + value=trainer.eval_output_info[key].avg, + step=trainer.global_step, + writer=trainer.vdl_writer, + ) diff --git a/jointContribution/yinglong/ppsci/solver/solver.py b/jointContribution/yinglong/ppsci/solver/solver.py new file mode 100644 index 0000000000..9622a2ed1d --- /dev/null +++ b/jointContribution/yinglong/ppsci/solver/solver.py @@ -0,0 +1,679 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +import itertools +import os +import sys +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Union + +import numpy as np +import paddle +import paddle.distributed as dist +import visualdl as vdl +from packaging import version +from paddle import amp +from paddle import jit +from paddle import nn +from paddle import optimizer as optim +from paddle.distributed import fleet +from typing_extensions import Literal + +import ppsci +from ppsci.utils import config +from ppsci.utils import expression +from ppsci.utils import logger +from ppsci.utils import misc +from ppsci.utils import save_load + + +class Solver: + """Class for solver. + + Args: + model (nn.Layer): Model. + constraint (Optional[Dict[str, ppsci.constraint.Constraint]]): Constraint(s) applied on model. Defaults to None. + output_dir (Optional[str]): Output directory. Defaults to "./output/". + optimizer (Optional[optimizer.Optimizer]): Optimizer object. Defaults to None. + lr_scheduler (Optional[optimizer.lr.LRScheduler]): Learning rate scheduler. Defaults to None. + epochs (int, optional): Training epoch(s). Defaults to 5. + iters_per_epoch (int, optional): Number of iterations within an epoch. Defaults to 20. + update_freq (int, optional): Update frequency of parameters. Defaults to 1. + save_freq (int, optional): Saving frequency for checkpoint. Defaults to 0. + log_freq (int, optional): Logging frequency. Defaults to 10. + eval_during_train (bool, optional): Whether evaluate model during training. Defaults to False. + start_eval_epoch (int, optional): Epoch number evaluation applied begin after. Defaults to 1. + eval_freq (int, optional): Evaluation frequency. Defaults to 1. + seed (int, optional): Random seed. Defaults to 42. + vdl_writer (Optional[vdl.LogWriter]): VisualDL writer object. Defaults to None. + device (Literal["cpu", "gpu", "xpu"], optional): Runtime device. Defaults to "gpu". + equation (Optional[Dict[str, ppsci.equation.PDE]]): Equation dict. Defaults to None. + geom (Optional[Dict[str, ppsci.geometry.Geometry]]): Geometry dict. Defaults to None. + validator (Optional[Dict[str, ppsci.validate.Validator]]): Validator dict. Defaults to None. + visualizer (Optional[Dict[str, ppsci.visualize.Visualizer]]): Visualizer dict. Defaults to None. + use_amp (bool, optional): Whether use AMP. Defaults to False. + amp_level (Literal["O1", "O2", "O0"], optional): AMP level. Defaults to "O0". + pretrained_model_path (Optional[str]): Pretrained model path. Defaults to None. + checkpoint_path (Optional[str]): Checkpoint path. Defaults to None. + compute_metric_by_batch (bool, optional): Whether calculate metrics after each batch during evaluate. Defaults to False. + eval_with_no_grad (bool, optional): Whether set `stop_gradient=True` for every Tensor if no differentiation + involved during computation, generally for save GPU memory and accelerate computing. Defaults to False. + to_static (bool, optional): Whether enable to_static for forward pass. Defaults to False. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20) + >>> opt = ppsci.optimizer.AdamW(1e-3)((model,)) + >>> geom = ppsci.geometry.Rectangle((0, 0), (1, 1)) + >>> pde_constraint = ppsci.constraint.InteriorConstraint( + ... {"u": lambda out: out["u"]}, + ... {"u": 0}, + ... geom, + ... { + ... "dataset": "IterableNamedArrayDataset", + ... "iters_per_epoch": 1, + ... "batch_size": 16, + ... }, + ... ppsci.loss.MSELoss("mean"), + ... name="EQ", + ... ) + >>> solver = ppsci.solver.Solver( + ... model, + ... {"EQ": pde_constraint}, + ... "./output", + ... opt, + ... None, + ... ) # doctest: +SKIP + """ + + def __init__( + self, + model: nn.Layer, + constraint: Optional[Dict[str, ppsci.constraint.Constraint]] = None, + output_dir: Optional[str] = "./output/", + optimizer: Optional[optim.Optimizer] = None, + lr_scheduler: Optional[optim.lr.LRScheduler] = None, + epochs: int = 5, + iters_per_epoch: int = 20, + update_freq: int = 1, + save_freq: int = 0, + log_freq: int = 10, + eval_during_train: bool = False, + start_eval_epoch: int = 1, + eval_freq: int = 1, + seed: int = 42, + vdl_writer: Optional[vdl.LogWriter] = None, + device: Literal["cpu", "gpu", "xpu"] = "gpu", + equation: Optional[Dict[str, ppsci.equation.PDE]] = None, + geom: Optional[Dict[str, ppsci.geometry.Geometry]] = None, + validator: Optional[Dict[str, ppsci.validate.Validator]] = None, + visualizer: Optional[Dict[str, ppsci.visualize.Visualizer]] = None, + use_amp: bool = False, + amp_level: Literal["O1", "O2", "O0"] = "O0", + pretrained_model_path: Optional[str] = None, + checkpoint_path: Optional[str] = None, + compute_metric_by_batch: bool = False, + eval_with_no_grad: bool = False, + to_static: bool = False, + ): + # set model + self.model = model + # set constraint + self.constraint = constraint + # set output directory + self.output_dir = output_dir + + # set optimizer + self.optimizer = optimizer + # set learning rate scheduler + self.lr_scheduler = lr_scheduler + + # set training hyper-parameter + self.epochs = epochs + self.iters_per_epoch = iters_per_epoch + # set update_freq for gradient accumulation + self.update_freq = update_freq + # set checkpoint saving frequency + self.save_freq = save_freq + # set logging frequency + self.log_freq = log_freq + + # set evaluation hyper-parameter + self.eval_during_train = eval_during_train + self.start_eval_epoch = start_eval_epoch + self.eval_freq = eval_freq + + # initialize traning log recorder for loss, time cost, metric, etc. + self.train_output_info = {} + self.train_time_info = { + "batch_cost": misc.AverageMeter("batch_cost", ".5f", postfix="s"), + "reader_cost": misc.AverageMeter("reader_cost", ".5f", postfix="s"), + } + + # initialize evaluation log recorder for loss, time cost, metric, etc. + self.eval_output_info = {} + self.eval_time_info = { + "batch_cost": misc.AverageMeter("batch_cost", ".5f", postfix="s"), + "reader_cost": misc.AverageMeter("reader_cost", ".5f", postfix="s"), + } + + # fix seed for reproducibility + self.seed = seed + + # set VisualDL tool + self.vdl_writer = vdl_writer + + # set running device + self.device = paddle.set_device(device) + # set equations for physics-driven or data-physics hybrid driven task, such as PINN + self.equation = equation + # set geometry for generating data + self.geom = {} if geom is None else geom + + # set validator + self.validator = validator + + # set visualizer + self.visualizer = visualizer + + # set automatic mixed precision(AMP) configuration + self.use_amp = use_amp + self.amp_level = amp_level + self.scaler = amp.GradScaler(True) if self.use_amp else None + + # load pretrained model, usually used for transfer learning + if pretrained_model_path is not None: + save_load.load_pretrain(self.model, pretrained_model_path, self.equation) + + # whether calculate metrics after each batch during evaluate + self.compute_metric_by_batch = compute_metric_by_batch + if validator is not None: + for metric in itertools.chain( + *[_v.metric.values() for _v in self.validator.values()] + ): + if metric.keep_batch ^ compute_metric_by_batch: + raise ValueError( + f"{misc.typename(metric)}.keep_batch should be " + f"{compute_metric_by_batch} when compute_metric_by_batch=" + f"{compute_metric_by_batch}." + ) + # whether set `stop_gradient=True` for every Tensor if no differentiation involved during computation + self.eval_with_no_grad = eval_with_no_grad + + # initialize an dict for tracking best metric during training + self.best_metric = { + "metric": float("inf"), + "epoch": 0, + } + # load model checkpoint, usually used for resume training + if checkpoint_path is not None: + loaded_metric = save_load.load_checkpoint( + checkpoint_path, self.model, self.optimizer, self.scaler, self.equation + ) + if isinstance(loaded_metric, dict): + self.best_metric.update(loaded_metric) + + # init logger without FileHandler if not initialized before + if logger._logger is None: + logger.init_logger("ppsci", None) + + # choosing an appropriate training function for different optimizers + if isinstance(self.optimizer, optim.LBFGS): + self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func + if self.update_freq != 1: + self.update_freq = 1 + logger.warning("Set update_freq to to 1 when using L-BFGS optimizer.") + else: + self.train_epoch_func = ppsci.solver.train.train_epoch_func + + # decorate model(s) and optimizer(s) for AMP + if self.use_amp: + self.model, self.optimizer = amp.decorate( + self.model, self.optimizer, self.amp_level + ) + + # wrap model and optimizer to parallel object + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + if self.world_size > 1: + # TODO(sensen): support different kind of DistributedStrategy + fleet.init(is_collective=True) + self.model = fleet.distributed_model(self.model) + if self.optimizer is not None: + self.optimizer = fleet.distributed_optimizer(self.optimizer) + logger.warning( + f"Detected world_size({self.world_size}) > 1, it is recommended to " + "scale up the learning rate and reduce the epochs or " + "iters_per_epoch according to the world_size both linearly." + ) + + self.global_step = 0 + + # log paddlepaddle's version + paddle_version = ( + paddle.__version__ + if version.Version(paddle.__version__) != version.Version("0.0.0") + else f"develop({paddle.version.commit[:7]})" + ) + logger.info(f"Using paddlepaddle {paddle_version} on device {self.device}") + + self.forward_helper = expression.ExpressionSolver() + + # whether enable static for forward pass, default to Fals + jit.enable_to_static(to_static) + logger.info(f"Set to_static={to_static} for forward computation.") + + @staticmethod + def from_config(cfg: Dict[str, Any]) -> Solver: + """Initialize solver from given config. + + Args: + cfg (Dict[str, Any]): Dict config, e.g. AttrDict parsed from yaml. + + Returns: + Solver: Initialized solver object. + """ + config.print_config(cfg) + # TODO(sensen): sanity check for config + output_dir = cfg["Global"]["output_dir"] + epochs = cfg["Global"]["epochs"] + iters_per_epoch = cfg["Global"]["iters_per_epoch"] + save_freq = cfg["Global"]["save_freq"] + eval_during_train = cfg["Global"]["eval_during_train"] + eval_freq = cfg["Global"]["eval_freq"] + + seed = cfg["Global"].get("seed", 42) + rank = dist.get_rank() + misc.set_random_seed(seed + rank) + + model = ppsci.arch.build_model(cfg["Arch"]) + geom = ppsci.geometry.build_geometry(cfg.get("Geometry", None)) + equation = ppsci.equation.build_equation(cfg.get("Equation", None)) + constraint = ppsci.constraint.build_constraint( + cfg["Global"].get("Constraint", None), + equation, + geom, + ) + optimizer, lr_scheduler = ppsci.optimizer.build_optimizer( + cfg["Global"]["Optimizer"], + model + ([eq for eq in equation.values()] if equation is not None else []), + epochs, + iters_per_epoch, + ) + + vdl_writer = None + if cfg["Global"].get("vdl_writer", False): + vdl_writer_path = os.path.join(output_dir, "vdl") + os.makedirs(vdl_writer_path, exist_ok=True) + vdl_writer = vdl.LogWriter(vdl_writer_path) + + log_freq = cfg["Global"].get("log_freq", 10) + device = cfg["Global"].get("device", "gpu") + validator = ppsci.validate.build_validator( + cfg.get("Validator", None), equation, geom + ) + visualizer = ppsci.visualize.build_visualizer(cfg.get("Visualizer", None)) + use_amp = "AMP" in cfg + amp_level = cfg["AMP"].pop("level", "O1").upper() if use_amp else "O0" + + start_eval_epoch = cfg["Global"].get("start_eval_epoch", 1) + update_freq = cfg["Global"].get("update_freq", 1) + pretrained_model_path = cfg["Global"].get("pretrained_model_path", None) + checkpoint_path = cfg["Global"].get("checkpoint_path", None) + compute_metric_by_batch = cfg["Global"].get("compute_metric_by_batch", False) + eval_with_no_grad = cfg["Global"].get("eval_with_no_grad", False) + + return Solver( + model, + constraint, + output_dir, + optimizer, + lr_scheduler, + epochs, + iters_per_epoch, + update_freq, + save_freq, + log_freq, + eval_during_train, + start_eval_epoch, + eval_freq, + seed, + vdl_writer, + device, + equation, + geom, + validator, + visualizer, + use_amp, + amp_level, + pretrained_model_path, + checkpoint_path, + compute_metric_by_batch, + eval_with_no_grad, + ) + + def train(self): + """Training.""" + self.global_step = self.best_metric["epoch"] * self.iters_per_epoch + 1 + + for epoch_id in range(self.best_metric["epoch"] + 1, self.epochs + 1): + self.train_epoch_func(self, epoch_id, self.log_freq) + + # log training summation at end of a epoch + metric_msg = ", ".join( + [self.train_output_info[key].avg_info for key in self.train_output_info] + ) + logger.info(f"[Train][Epoch {epoch_id}/{self.epochs}][Avg] {metric_msg}") + self.train_output_info.clear() + + cur_metric = float("inf") + # evaluate during training + if ( + self.eval_during_train + and epoch_id % self.eval_freq == 0 + and epoch_id >= self.start_eval_epoch + ): + cur_metric = self.eval(epoch_id) + if cur_metric < self.best_metric["metric"]: + self.best_metric["metric"] = cur_metric + self.best_metric["epoch"] = epoch_id + save_load.save_checkpoint( + self.model, + self.optimizer, + self.scaler, + self.best_metric, + self.output_dir, + "best_model", + self.equation, + ) + logger.info( + f"[Eval][Epoch {epoch_id}]" + f"[best metric: {self.best_metric['metric']}]" + ) + logger.scaler("eval_metric", cur_metric, epoch_id, self.vdl_writer) + + # visualize after evaluation + if self.visualizer is not None: + self.visualize(epoch_id) + + # update learning rate by epoch + if self.lr_scheduler is not None and self.lr_scheduler.by_epoch: + self.lr_scheduler.step() + + # save epoch model every save_freq epochs + if self.save_freq > 0 and epoch_id % self.save_freq == 0: + save_load.save_checkpoint( + self.model, + self.optimizer, + self.scaler, + {"metric": cur_metric, "epoch": epoch_id}, + self.output_dir, + f"epoch_{epoch_id}", + self.equation, + ) + + # save the latest model for convenient resume training + save_load.save_checkpoint( + self.model, + self.optimizer, + self.scaler, + {"metric": cur_metric, "epoch": epoch_id}, + self.output_dir, + "latest", + self.equation, + ) + + # close VisualDL + if self.vdl_writer is not None: + self.vdl_writer.close() + + @misc.run_on_eval_mode + def eval(self, epoch_id: int = 0) -> float: + """Evaluation. + + Args: + epoch_id (int, optional): Epoch id. Defaults to 0. + + Returns: + float: The value of the evaluation, used to judge the quality of the model. + """ + # set eval func + self.eval_func = ppsci.solver.eval.eval_func + + result = self.eval_func(self, epoch_id, self.log_freq) + for key in self.eval_output_info: + logger.info(f"[Avg] {self.eval_output_info[key].avg_info}") + # warning: the folowing code is only applicable to cases based on fourcastnet + if "output_" in key: + key_split = key.replace("(Sup_Validator)", "").split(".") + if len(key_split) == 3: + metric_name, output_step, var_name = key_split + step = int(output_step.replace("output_", "")) + logger.scaler( + f"{metric_name}.{var_name}", + self.eval_output_info[key].avg, + step, + self.vdl_writer, + ) + else: + metric_name, output_step = key_split + step = int(output_step.replace("output_", "")) + logger.scaler( + f"{metric_name}", + self.eval_output_info[key].avg, + step, + self.vdl_writer, + ) + # metric_msg = ", ".join( + # [self.eval_output_info[key].avg_info for key in self.eval_output_info] + # ) + # logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}") + self.eval_output_info.clear() + # close VisualDL + if self.vdl_writer is not None: + self.vdl_writer.close() + return result + + @misc.run_on_eval_mode + def visualize(self, epoch_id: int = 0): + """Visualization. + + Args: + epoch_id (int, optional): Epoch id. Defaults to 0. + """ + # set visualize func + self.visu_func = ppsci.solver.visu.visualize_func + + self.visu_func(self, epoch_id) + logger.info(f"[Visualize][Epoch {epoch_id}] Finished visualization") + + @misc.run_on_eval_mode + def predict( + self, + input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]], + expr_dict: Optional[Dict[str, Callable]] = None, + batch_size: int = 64, + no_grad: bool = True, + ) -> Dict[str, paddle.Tensor]: + """Pure prediction using model.forward(...) and expression(optional, if given). + + Args: + input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Input data in dict. + expr_dict (Optional[Dict[str, Callable]]): Expression dict, which guide to + compute equation variable with callable function. Defaults to None. + batch_size (int, optional): Predicting by batch size. Defaults to 64. + no_grad (bool): Whether set stop_gradient=True for entire prediction, mainly + for memory-efficiency. Defaults to True. + Returns: + Dict[str, paddle.Tensor]: Prediction in dict. + """ + num_samples = len(next(iter(input_dict.values()))) + num_pad = (self.world_size - num_samples % self.world_size) % self.world_size + # pad with last element if `num_samples` is not divisible by `world_size` + # ensuring every device get same number of data. + if num_pad > 0: + for k, v in input_dict.items(): + repeat_times = (num_pad, *(1 for _ in range(v.ndim - 1))) + input_dict[k] = paddle.concat( + ( + v, + paddle.tile(v[num_samples - 1 : num_samples], repeat_times), + ), + ) + + num_samples_pad = num_samples + num_pad + local_num_samples_pad = num_samples_pad // self.world_size + local_input_dict = ( + {k: v[self.rank :: self.world_size] for k, v in input_dict.items()} + if self.world_size > 1 + else input_dict + ) + local_batch_num = (local_num_samples_pad + (batch_size - 1)) // batch_size + pred_dict = misc.Prettydefaultdict(list) + with self.no_grad_context_manager(no_grad), self.no_sync_context_manager( + self.world_size > 1, self.model + ): + for batch_id in range(local_batch_num): + batch_input_dict = {} + st = batch_id * batch_size + ed = min(local_num_samples_pad, (batch_id + 1) * batch_size) + + # prepare batch input dict + for key in local_input_dict: + if not paddle.is_tensor(local_input_dict[key]): + batch_input_dict[key] = paddle.to_tensor( + local_input_dict[key][st:ed], paddle.get_default_dtype() + ) + else: + batch_input_dict[key] = local_input_dict[key][st:ed] + batch_input_dict[key].stop_gradient = no_grad + + # forward + with self.autocast_context_manager(self.use_amp, self.amp_level): + batch_output_dict = self.forward_helper.visu_forward( + expr_dict, batch_input_dict, self.model + ) + + # collect batch data + for key, batch_output in batch_output_dict.items(): + pred_dict[key].append(batch_output.detach()) + + # concatenate local predictions + pred_dict = {key: paddle.concat(value) for key, value in pred_dict.items()} + + if self.world_size > 1: + # gather global predictions from all devices if world_size > 1 + pred_dict = { + key: misc.all_gather(value) for key, value in pred_dict.items() + } + + # rearange predictions as the same order of input_dict according to inverse + # permutation, then discard predictions of padding data at the end + perm = np.arange(num_samples_pad, dtype="int64") + perm = np.concatenate( + [perm[rank :: self.world_size] for rank in range(self.world_size)], + axis=0, + ) + perm_inv = np.empty_like(perm) + perm_inv[perm] = np.arange(num_samples_pad, dtype="int64") + perm_inv = paddle.to_tensor(perm_inv) + pred_dict = { + key: value[perm_inv][:num_samples] + for key, value in pred_dict.items() + } + + return pred_dict + + @misc.run_on_eval_mode + def export(self): + """Export to inference model.""" + raise NotImplementedError("model export is not supported yet.") + + def autocast_context_manager( + self, enable: bool, level: Literal["O0", "O1", "O2"] = "O1" + ) -> contextlib.AbstractContextManager: + """Smart autocast context manager for Auto Mix Precision. + + Args: + enable (bool): Enable autocast. + level (Literal["O0", "O1", "O2"]): Autocast level. + + Returns: + contextlib.AbstractContextManager: Smart autocast context manager. + """ + if enable: + ctx_manager = amp.auto_cast(level=level) + else: + ctx_manager = ( + contextlib.nullcontext() + if sys.version_info >= (3, 7) + else contextlib.suppress() + ) + return ctx_manager + + def no_grad_context_manager( + self, enable: bool + ) -> contextlib.AbstractContextManager: + """Smart no_grad context manager. + + Args: + enable (bool): Enable no_grad. + + Returns: + contextlib.AbstractContextManager: Smart no_grad context manager. + """ + if enable: + ctx_manager = paddle.no_grad() + else: + ctx_manager = ( + contextlib.nullcontext() + if sys.version_info >= (3, 7) + else contextlib.suppress() + ) + return ctx_manager + + def no_sync_context_manager( + self, + enable: bool, + ddp_model: paddle.DataParallel, + ) -> contextlib.AbstractContextManager: + """Smart no_sync context manager for given model. + NOTE: Only `paddle.DataParallel` object has `no_sync` interface. + + Args: + enable (bool): Enable no_sync. + + Returns: + contextlib.AbstractContextManager: Smart no_sync context manager. + """ + if enable: + if not isinstance(ddp_model, paddle.DataParallel): + raise TypeError( + "no_sync interface is only for model with type paddle.DataParallel, " + f"but got type {misc.typename(ddp_model)}" + ) + ctx_manager = ddp_model.no_sync() + else: + ctx_manager = ( + contextlib.nullcontext() + if sys.version_info >= (3, 7) + else contextlib.suppress() + ) + return ctx_manager diff --git a/jointContribution/yinglong/ppsci/solver/train.py b/jointContribution/yinglong/ppsci/solver/train.py new file mode 100644 index 0000000000..4f6336e289 --- /dev/null +++ b/jointContribution/yinglong/ppsci/solver/train.py @@ -0,0 +1,224 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import TYPE_CHECKING + +from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu + +from ppsci.solver import printer +from ppsci.utils import misc +from ppsci.utils import profiler + +if TYPE_CHECKING: + from ppsci import solver + + +def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int): + """Train program for one epoch + + Args: + solver (solver.Solver): Main solver. + epoch_id (int): Epoch id. + log_freq (int): Log training information every `log_freq` steps. + """ + batch_tic = time.perf_counter() + + for iter_id in range(1, solver.iters_per_epoch + 1): + total_loss = 0 + loss_dict = misc.Prettydefaultdict(float) + loss_dict["loss"] = 0.0 + total_batch_size = 0 + reader_cost = 0 + batch_cost = 0 + reader_tic = time.perf_counter() + + input_dicts = [] + label_dicts = [] + weight_dicts = [] + for _, _constraint in solver.constraint.items(): + try: + input_dict, label_dict, weight_dict,input_time = next(_constraint.data_iter) + except StopIteration: + _constraint.data_iter = iter(_constraint.data_loader) + input_dict, label_dict, weight_dict,input_time = next(_constraint.data_iter) + # profile code below + # profiler.add_profiler_step(solver.cfg["profiler_options"]) + if iter_id == 5: + # 5 step for warmup + for key in solver.train_time_info: + solver.train_time_info[key].reset() + reader_cost += time.perf_counter() - reader_tic + for v in input_dict.values(): + v.stop_gradient = False + + # gather each constraint's input, label, weight to a list + input_dicts.append(input_dict) + label_dicts.append(label_dict) + weight_dicts.append(weight_dict) + total_batch_size += next(iter(input_dict.values())).shape[0] + reader_tic = time.perf_counter() + + with solver.no_sync_context_manager(solver.world_size > 1, solver.model): + # forward for every constraint, including model and equation expression + with solver.autocast_context_manager(solver.use_amp, solver.amp_level): + constraint_losses = solver.forward_helper.train_forward( + [ + _constraint.output_expr + for _constraint in solver.constraint.values() + ], + input_dicts, + solver.model, + solver.constraint, + label_dicts, + weight_dicts, + input_time, + + ) + # accumulate all losses + for i, _constraint in enumerate(solver.constraint.values()): + total_loss += constraint_losses[i] + loss_dict[_constraint.name] += ( + float(constraint_losses[i]) / solver.update_freq + ) + if solver.update_freq > 1: + total_loss = total_loss / solver.update_freq + loss_dict["loss"] = float(total_loss) + + # backward + if solver.use_amp: + total_loss_scaled = solver.scaler.scale(total_loss) + total_loss_scaled.backward() + else: + total_loss.backward() + + # update parameters + if iter_id % solver.update_freq == 0 or iter_id == solver.iters_per_epoch: + if solver.world_size > 1: + # fuse + allreduce manually before optimization if use DDP + no_sync + # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622 + hpu.fused_allreduce_gradients(list(solver.model.parameters()), None) + if solver.use_amp: + solver.scaler.minimize(solver.optimizer, total_loss_scaled) + else: + solver.optimizer.step() + solver.optimizer.clear_grad() + + # update learning rate by step + if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch: + solver.lr_scheduler.step() + + batch_cost += time.perf_counter() - batch_tic + + # update and log training information + solver.global_step += 1 + solver.train_time_info["reader_cost"].update(reader_cost) + solver.train_time_info["batch_cost"].update(batch_cost) + printer.update_train_loss(solver, loss_dict, total_batch_size) + if iter_id == 1 or iter_id % log_freq == 0: + printer.log_train_info(solver, total_batch_size, epoch_id, iter_id) + + batch_tic = time.perf_counter() + + +def train_LBFGS_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int): + """Train function for one epoch with L-BFGS optimizer. + + Args: + solver (solver.Solver): Main solver. + epoch_id (int): Epoch id. + log_freq (int): Log training information every `log_freq` steps. + """ + batch_tic = time.perf_counter() + + for iter_id in range(1, solver.iters_per_epoch + 1): + loss_dict = misc.Prettydefaultdict(float) + loss_dict["loss"] = 0.0 + total_batch_size = 0 + reader_cost = 0 + batch_cost = 0 + reader_tic = time.perf_counter() + + input_dicts = [] + label_dicts = [] + weight_dicts = [] + for _, _constraint in solver.constraint.items(): + input_dict, label_dict, weight_dict = next(_constraint.data_iter) + reader_cost += time.perf_counter() - reader_tic + for v in input_dict.values(): + v.stop_gradient = False + + # gather all constraint data into list + input_dicts.append(input_dict) + label_dicts.append(label_dict) + weight_dicts.append(weight_dict) + total_batch_size += next(iter(input_dict.values())).shape[0] + reader_tic = time.perf_counter() + + def closure(): + """Forward-backward closure function for LBFGS optimizer. + + Returns: + Tensor: Computed loss. + """ + total_loss = 0 + with solver.no_sync_context_manager(solver.world_size > 1, solver.model): + with solver.autocast_context_manager(solver.use_amp, solver.amp_level): + # forward for every constraint, including model and equation expression + constraint_losses = solver.forward_helper.train_forward( + [ + _constraint.output_expr + for _constraint in solver.constraint.values() + ], + input_dicts, + solver.model, + solver.constraint, + label_dicts, + weight_dicts, + ) + # accumulate all losses + for i, _constraint in enumerate(solver.constraint.values()): + total_loss += constraint_losses[i] + loss_dict[_constraint.name] = float(constraint_losses[i]) + loss_dict["loss"] = float(total_loss) + + # backward + solver.optimizer.clear_grad() + total_loss.backward() + + if solver.world_size > 1: + # fuse + allreduce manually before optimization if use DDP model + # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622 + hpu.fused_allreduce_gradients(list(solver.model.parameters()), None) + + return total_loss + + # update parameters + solver.optimizer.step(closure) + + # update learning rate by step + if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch: + solver.lr_scheduler.step() + + batch_cost += time.perf_counter() - batch_tic + + # update and log training information + solver.global_step += 1 + solver.train_time_info["reader_cost"].update(reader_cost) + solver.train_time_info["batch_cost"].update(batch_cost) + printer.update_train_loss(solver, loss_dict, total_batch_size) + if iter_id == 1 or iter_id % log_freq == 0: + printer.log_train_info(solver, total_batch_size, epoch_id, iter_id) + + batch_tic = time.perf_counter() diff --git a/jointContribution/yinglong/ppsci/solver/visu.py b/jointContribution/yinglong/ppsci/solver/visu.py new file mode 100644 index 0000000000..373f6e9992 --- /dev/null +++ b/jointContribution/yinglong/ppsci/solver/visu.py @@ -0,0 +1,90 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path as osp +from typing import TYPE_CHECKING + +import paddle + +from ppsci.utils import misc + +if TYPE_CHECKING: + from ppsci import solver + + +def visualize_func(solver: "solver.Solver", epoch_id: int): + """Visualization program + + Args: + solver (solver.Solver): Main Solver. + epoch_id (int): Epoch id. + """ + for _, _visualizer in solver.visualizer.items(): + all_input = misc.Prettydefaultdict(list) + all_output = misc.Prettydefaultdict(list) + + input_dict = _visualizer.input_dict + batch_size = _visualizer.batch_size + num_samples = len(next(iter(input_dict.values()))) + batch_num = (num_samples + (batch_size - 1)) // batch_size + + for batch_id in range(batch_num): + batch_input_dict = {} + st = batch_id * batch_size + ed = min(num_samples, (batch_id + 1) * batch_size) + + # prepare batch input dict + for key in input_dict: + if not paddle.is_tensor(input_dict[key]): + batch_input_dict[key] = paddle.to_tensor( + input_dict[key][st:ed], paddle.get_default_dtype() + ) + else: + batch_input_dict[key] = input_dict[key][st:ed] + batch_input_dict[key].stop_gradient = False + + # forward + with solver.no_grad_context_manager(solver.eval_with_no_grad): + batch_output_dict = solver.forward_helper.visu_forward( + _visualizer.output_expr, batch_input_dict, solver.model + ) + + # collect batch data + for key, batch_input in batch_input_dict.items(): + all_input[key].append( + batch_input.detach() + if solver.world_size == 1 + else misc.all_gather(batch_input.detach()) + ) + for key, batch_output in batch_output_dict.items(): + all_output[key].append( + batch_output.detach() + if solver.world_size == 1 + else misc.all_gather(batch_output.detach()) + ) + + # concate all data + for key in all_input: + all_input[key] = paddle.concat(all_input[key]) + for key in all_output: + all_output[key] = paddle.concat(all_output[key]) + + # save visualization + if solver.rank == 0: + visual_dir = osp.join(solver.output_dir, "visual", f"epoch_{epoch_id}") + os.makedirs(visual_dir, exist_ok=True) + _visualizer.save( + osp.join(visual_dir, _visualizer.prefix), {**all_input, **all_output} + ) diff --git a/jointContribution/yinglong/ppsci/utils/__init__.py b/jointContribution/yinglong/ppsci/utils/__init__.py new file mode 100644 index 0000000000..1a6f0b05bc --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/__init__.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ppsci.utils import initializer +from ppsci.utils import logger +from ppsci.utils import misc +from ppsci.utils import reader +from ppsci.utils.checker import dynamic_import_to_globals +from ppsci.utils.checker import run_check +from ppsci.utils.config import AttrDict +from ppsci.utils.expression import ExpressionSolver +from ppsci.utils.misc import AverageMeter +from ppsci.utils.misc import set_random_seed +from ppsci.utils.reader import load_csv_file +from ppsci.utils.reader import load_mat_file +from ppsci.utils.reader import load_vtk_file +from ppsci.utils.reader import load_vtk_with_time_file +from ppsci.utils.save_load import load_checkpoint +from ppsci.utils.save_load import load_pretrain +from ppsci.utils.save_load import save_checkpoint + +__all__ = [ + "initializer", + "logger", + "misc", + "reader", + "load_csv_file", + "load_mat_file", + "load_vtk_file", + "load_vtk_with_time_file", + "dynamic_import_to_globals", + "run_check", + "AttrDict", + "ExpressionSolver", + "AverageMeter", + "set_random_seed", + "load_checkpoint", + "load_pretrain", + "save_checkpoint", +] diff --git a/jointContribution/yinglong/ppsci/utils/checker.py b/jointContribution/yinglong/ppsci/utils/checker.py new file mode 100644 index 0000000000..c8777d5e72 --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/checker.py @@ -0,0 +1,150 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import traceback +from typing import Dict +from typing import Tuple +from typing import Union + +import paddle + +from ppsci.utils import logger + + +def run_check() -> None: + """Check whether PaddleScience is installed correctly and running successfully on + your system. + + Examples: + >>> import ppsci + >>> ppsci.utils.run_check() # doctest: +SKIP + """ + # test demo code below. + import logging + + import ppsci + + try: + ppsci.utils.set_random_seed(42) + ppsci.utils.logger.init_logger() + model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 3, 16, "tanh") + + equation = {"NavierStokes": ppsci.equation.NavierStokes(0.01, 1.0, 2, False)} + + geom = {"rect": ppsci.geometry.Rectangle((-0.05, -0.05), (0.05, 0.05))} + + ITERS_PER_EPOCH = 5 + train_dataloader_cfg = { + "dataset": "IterableNamedArrayDataset", + "iters_per_epoch": ITERS_PER_EPOCH, + } + + NPOINT_PDE = 8**2 + pde_constraint = ppsci.constraint.InteriorConstraint( + equation["NavierStokes"].equations, + {"continuity": 0, "momentum_x": 0, "momentum_y": 0}, + geom["rect"], + {**train_dataloader_cfg, "batch_size": NPOINT_PDE}, + ppsci.loss.MSELoss("sum"), + evenly=True, + weight_dict={ + "continuity": 0.0001, + "momentum_x": 0.0001, + "momentum_y": 0.0001, + }, + name="EQ", + ) + constraint = {pde_constraint.name: pde_constraint} + + residual_validator = ppsci.validate.GeometryValidator( + equation["NavierStokes"].equations, + {"continuity": 0, "momentum_x": 0, "momentum_y": 0}, + geom["rect"], + { + "dataset": "NamedArrayDataset", + "total_size": 8**2, + "batch_size": 32, + "sampler": {"name": "BatchSampler"}, + }, + ppsci.loss.MSELoss("sum"), + evenly=True, + metric={"MSE": ppsci.metric.MSE(False)}, + name="Residual", + ) + validator = {residual_validator.name: residual_validator} + + EPOCHS = 2 + optimizer = ppsci.optimizer.Adam(0.001)((model,)) + solver = ppsci.solver.Solver( + model, + constraint, + None, + optimizer, + None, + EPOCHS, + ITERS_PER_EPOCH, + device=paddle.device.get_device(), + equation=equation, + validator=validator, + ) + solver.train() + solver.eval(EPOCHS) + except Exception as e: + traceback.print_exc() + logging.warning( + f"PaddleScience meets some problem with \n {repr(e)} \nplease check whether " + "Paddle's version and PaddleScience's version are both correct." + ) + else: + print("PaddleScience is installed successfully.✨ 🍰 ✨") + + +def dynamic_import_to_globals( + names: Union[str, Tuple[str, ...]], alias: Dict[str, str] = None +) -> bool: + """Import module and add it to globals() by given names dynamically. + + Args: + names (Union[str, Tuple[str, ...]]): Module name or list of module names. + alias (Dict[str, str]): Alias name of module when imported into globals(). + + Returns: + bool: Whether given names all exist. + """ + if isinstance(names, str): + names = (names,) + + if alias is None: + alias = {} + + for name in names: + # find module in environment by it's name and alias(if given) + module_spec = importlib.util.find_spec(name) + if module_spec is None and name in alias: + module_spec = importlib.util.find_spec(alias[name]) + + # log error and return False if module do not exist + if not module_spec: + logger.error(f"Module {name} should be installed first.") + return False + + # module exist, add to globals() if not in globals() + add_name = name + if add_name in alias: + add_name = alias[add_name] + if add_name not in globals(): + globals()[add_name] = importlib.import_module(name) + + return True diff --git a/jointContribution/yinglong/ppsci/utils/config.py b/jointContribution/yinglong/ppsci/utils/config.py new file mode 100644 index 0000000000..595685152a --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/config.py @@ -0,0 +1,211 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import copy +import os + +import yaml +from paddle import static + +from ppsci.utils import logger + +__all__ = ["get_config", "replace_shape_with_inputspec_", "AttrDict"] + + +class AttrDict(dict): + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + if key in self.__dict__: + self.__dict__[key] = value + else: + self[key] = value + + def __deepcopy__(self, content): + return AttrDict(copy.deepcopy(dict(self))) + + +def create_attr_dict(yaml_config): + from ast import literal_eval + + for key, value in yaml_config.items(): + if isinstance(value, dict): + yaml_config[key] = value = AttrDict(value) + if isinstance(value, str): + try: + value = literal_eval(value) + except BaseException: + pass + if isinstance(value, AttrDict): + create_attr_dict(yaml_config[key]) + else: + yaml_config[key] = value + + +def parse_config(cfg_file): + """Load a config file into AttrDict""" + with open(cfg_file, "r") as fopen: + yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader)) + create_attr_dict(yaml_config) + return yaml_config + + +def print_dict(d, delimiter=0): + """ + Recursively visualize a dict and + indenting acrrording by the relationship of keys. + """ + placeholder = "-" * 60 + for k, v in d.items(): + if isinstance(v, dict): + logger.info(f"{delimiter * ' '}{k} : ") + print_dict(v, delimiter + 4) + elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict): + logger.info(f"{delimiter * ' '}{k} : ") + for value in v: + print_dict(value, delimiter + 2) + else: + logger.info(f"{delimiter * ' '}{k} : {v}") + + if k[0].isupper() and delimiter == 0: + logger.info(placeholder) + + +def print_config(config): + """ + Visualize configs + Arguments: + config: configs + """ + logger.advertise() + print_dict(config) + + +def override(dl, ks, v): + """ + Recursively replace dict of list + Args: + dl(dict or list): dict or list to be replaced + ks(list): list of keys + v(str): value to be replaced + """ + + def str2num(v): + try: + return eval(v) + except Exception: + return v + + if not isinstance(dl, (list, dict)): + raise ValueError(f"{dl} should be a list or a dict") + if len(ks) <= 0: + raise ValueError("lenght of keys should be larger than 0") + + if isinstance(dl, list): + k = str2num(ks[0]) + if len(ks) == 1: + if k >= len(dl): + raise ValueError(f"index({k}) out of range({dl})") + dl[k] = str2num(v) + else: + override(dl[k], ks[1:], v) + else: + if len(ks) == 1: + # assert ks[0] in dl, (f"{ks[0]} is not exist in {dl}") + if not ks[0] in dl: + print(f"A new field ({ks[0]}) detected!") + dl[ks[0]] = str2num(v) + else: + if ks[0] not in dl.keys(): + dl[ks[0]] = {} + print(f"A new Series field ({ks[0]}) detected!") + override(dl[ks[0]], ks[1:], v) + + +def override_config(config, options=None): + """ + Recursively override the config + Args: + config(dict): dict to be replaced + options(list): list of pairs(key0.key1.idx.key2=value) + such as: [ + "topk=2", + "VALID.transforms.1.ResizeImage.resize_short=300" + ] + Returns: + config(dict): replaced config + """ + if options is not None: + for opt in options: + assert isinstance(opt, str), f"option({opt}) should be a str" + assert ( + "=" in opt + ), f"option({opt}) should contain a = to distinguish between key and value" + pair = opt.split("=") + assert len(pair) == 2, "there can be only a = in the option" + key, value = pair + keys = key.split(".") + override(config, keys, value) + return config + + +def get_config(fname, overrides=None, show=False): + """ + Read config from file + """ + if not os.path.exists(fname): + raise FileNotFoundError(f"config file({fname}) is not exist") + config = parse_config(fname) + override_config(config, overrides) + if show: + print_config(config) + return config + + +def parse_args(): + parser = argparse.ArgumentParser("paddlescience running script") + parser.add_argument("-e", "--epochs", type=int, help="training epochs") + parser.add_argument("-o", "--output_dir", type=str, help="output directory") + parser.add_argument("-n", "--num_timestamps", type=int, help="num_timestamps") + parser.add_argument( + "--to_static", + action="store_true", + help="whether enable to_static for forward computation", + ) + + args = parser.parse_args() + return args + + +def _is_num_seq(seq): + # whether seq is all int number(it is a shape) + return isinstance(seq, (list, tuple)) and all(isinstance(x, int) for x in seq) + + +def replace_shape_with_inputspec_(node: AttrDict): + if _is_num_seq(node): + return True + + if isinstance(node, dict): + for key in node: + if replace_shape_with_inputspec_(node[key]): + node[key] = static.InputSpec(node[key]) + elif isinstance(node, list): + for i in range(len(node)): + if replace_shape_with_inputspec_(node[i]): + node[i] = static.InputSpec(node[i]) + + return False diff --git a/jointContribution/yinglong/ppsci/utils/download.py b/jointContribution/yinglong/ppsci/utils/download.py new file mode 100644 index 0000000000..4c0faf9a5e --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/download.py @@ -0,0 +1,283 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import os +import os.path as osp +import shutil +import tarfile +import time +import zipfile + +import requests +import tqdm + +from ppsci.utils import logger + +__all__ = ["get_weights_path_from_url"] + +WEIGHTS_HOME = osp.expanduser("~/.paddlesci/weights") + +DOWNLOAD_RETRY_LIMIT = 3 + + +def is_url(path): + """ + Whether path is URL. + Args: + path (string): URL string or not. + """ + return path.startswith("http://") or path.startswith("https://") + + +def get_weights_path_from_url(url, md5sum=None): + """Get weights path from WEIGHT_HOME, if not exists, + download it from url. + + Args: + url (str): download url + md5sum (str): md5 sum of download package + + Returns: + str: a local path to save downloaded weights. + """ + path = get_path_from_url(url, WEIGHTS_HOME, md5sum) + return path + + +def _map_path(url, root_dir): + # parse path after download under root_dir + fname = osp.split(url)[-1] + fpath = fname + return osp.join(root_dir, fpath) + + +def get_path_from_url(url, root_dir, md5sum=None, check_exist=True, decompress=True): + """Download from given url to root_dir. + if file or directory specified by url is exists under + root_dir, return the path directly, otherwise download + from url and decompress it, return the path. + + Args: + url (str): download url + root_dir (str): root dir for downloading, it should be + WEIGHTS_HOME or DATASET_HOME + md5sum (str): md5 sum of download package + + Returns: + str: a local path to save downloaded models & weights & datasets. + """ + if not is_url(url): + raise ValueError(f"Given url({url}) is not valid") + # parse path after download to decompress under root_dir + fullpath = _map_path(url, root_dir) + # Mainly used to solve the problem of downloading data from different + # machines in the case of multiple machines. Different nodes will download + # data, and the same node will only download data once. + rank_id_curr_node = int(os.environ.get("PADDLE_RANK_IN_NODE", 0)) + + if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum): + logger.info(f"Found {fullpath}") + else: + if rank_id_curr_node == 0: + fullpath = _download(url, root_dir, md5sum) + else: + while not os.path.exists(fullpath): + time.sleep(1) + + if rank_id_curr_node == 0: + if decompress and ( + tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath) + ): + fullpath = _decompress(fullpath) + + return fullpath + + +def _download(url, path, md5sum=None): + """ + Download from url, save to path. + + url (str): download url + path (str): download to given path + """ + if not osp.exists(path): + os.makedirs(path) + + fname = osp.split(url)[-1] + fullname = osp.join(path, fname) + retry_cnt = 0 + + while not (osp.exists(fullname) and _md5check(fullname, md5sum)): + if retry_cnt < DOWNLOAD_RETRY_LIMIT: + retry_cnt += 1 + else: + raise RuntimeError(f"Download from {url} failed. " "Retry limit reached") + + logger.info(f"Downloading {fname} from {url}") + + try: + req = requests.get(url, stream=True) + except Exception as e: # requests.exceptions.ConnectionError + logger.info( + f"Downloading {fname} from {url} failed {retry_cnt + 1} times with exception {str(e)}" + ) + time.sleep(1) + continue + + if req.status_code != 200: + raise RuntimeError( + f"Downloading from {url} failed with code " f"{req.status_code}!" + ) + + # For protecting download interupted, download to + # tmp_fullname firstly, move tmp_fullname to fullname + # after download finished + tmp_fullname = fullname + "_tmp" + total_size = req.headers.get("content-length") + with open(tmp_fullname, "wb") as f: + if total_size: + with tqdm.tqdm(total=(int(total_size) + 1023) // 1024) as pbar: + for chunk in req.iter_content(chunk_size=1024): + f.write(chunk) + pbar.update(1) + else: + for chunk in req.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + shutil.move(tmp_fullname, fullname) + + return fullname + + +def _md5check(fullname, md5sum=None): + if md5sum is None: + return True + + logger.info(f"File {fullname} md5 checking...") + md5 = hashlib.md5() + with open(fullname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + md5.update(chunk) + calc_md5sum = md5.hexdigest() + + if calc_md5sum != md5sum: + logger.info( + f"File {fullname} md5 check failed, {calc_md5sum}(calc) != " + f"{md5sum}(base)" + ) + return False + return True + + +def _decompress(fname): + """ + Decompress for zip and tar file + """ + logger.info(f"Decompressing {fname}...") + + # For protecting decompressing interupted, + # decompress to fpath_tmp directory firstly, if decompress + # successed, move decompress files to fpath and delete + # fpath_tmp and remove download compress file. + + if tarfile.is_tarfile(fname): + uncompressed_path = _uncompress_file_tar(fname) + elif zipfile.is_zipfile(fname): + uncompressed_path = _uncompress_file_zip(fname) + else: + raise TypeError(f"Unsupport compress file type {fname}") + + return uncompressed_path + + +def _uncompress_file_zip(filepath): + with zipfile.ZipFile(filepath, "r") as files: + file_list = files.namelist() + + file_dir = os.path.dirname(filepath) + + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + + for item in file_list: + files.extract(item, file_dir) + + elif _is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + + for item in file_list: + files.extract(item, file_dir) + + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + if not os.path.exists(uncompressed_path): + os.makedirs(uncompressed_path) + for item in file_list: + files.extract(item, os.path.join(file_dir, rootpath)) + + return uncompressed_path + + +def _uncompress_file_tar(filepath, mode="r:*"): + with tarfile.open(filepath, mode) as files: + file_list = files.getnames() + + file_dir = os.path.dirname(filepath) + + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + for item in file_list: + files.extract(item, file_dir) + elif _is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + for item in file_list: + files.extract(item, file_dir) + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + if not os.path.exists(uncompressed_path): + os.makedirs(uncompressed_path) + + for item in file_list: + files.extract(item, os.path.join(file_dir, rootpath)) + + return uncompressed_path + + +def _is_a_single_file(file_list): + if len(file_list) == 1 and file_list[0].find(os.sep) < -1: + return True + return False + + +def _is_a_single_dir(file_list): + new_file_list = [] + for file_path in file_list: + if "/" in file_path: + file_path = file_path.replace("/", os.sep) + elif "\\" in file_path: + file_path = file_path.replace("\\", os.sep) + new_file_list.append(file_path) + + file_name = new_file_list[0].split(os.sep)[0] + for i in range(1, len(new_file_list)): + if file_name != new_file_list[i].split(os.sep)[0]: + return False + return True diff --git a/jointContribution/yinglong/ppsci/utils/expression.py b/jointContribution/yinglong/ppsci/utils/expression.py new file mode 100644 index 0000000000..74309b9305 --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/expression.py @@ -0,0 +1,192 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Tuple + +from paddle import jit +from paddle import nn + +if TYPE_CHECKING: + import paddle + + from ppsci import constraint + from ppsci import validate + +# from ppsci.autodiff import clear + + +class ExpressionSolver(nn.Layer): + """Expression computing helper, which compute named result according to corresponding + function and related inputs. + + Examples: + >>> import ppsci + >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128) + >>> expr_solver = ExpressionSolver() + """ + + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Use train_forward/eval_forward/visu_forward instead of forward." + ) + + @jit.to_static + def train_forward( + self, + expr_dicts: Tuple[Dict[str, Callable], ...], + input_dicts: Tuple[Dict[str, "paddle.Tensor"], ...], + model: nn.Layer, + constraint: Dict[str, "constraint.Constraint"], + label_dicts: Tuple[Dict[str, "paddle.Tensor"], ...], + weight_dicts: Tuple[Dict[str, "paddle.Tensor"], ...], + input_time, + ) -> Tuple["paddle.Tensor", ...]: + """Forward computation for training, including model forward and equation + forward. + + Args: + expr_dicts (Tuple[Dict[str, Callable], ...]): Tuple of expression dicts. + input_dicts (Tuple[Dict[str, paddle.Tensor], ...]): Tuple of input dicts. + model (nn.Layer): NN model. + constraint (Dict[str, "constraint.Constraint"]): Constraint dict. + label_dicts (Tuple[Dict[str, paddle.Tensor], ...]): Tuple of label dicts. + weight_dicts (Tuple[Dict[str, paddle.Tensor], ...]): Tuple of weight dicts. + + Returns: + Tuple[paddle.Tensor, ...]: Tuple of losses for each constraint. + """ + output_dicts = [] + for i, expr_dict in enumerate(expr_dicts): + # model forward + if callable(next(iter(expr_dict.values()))): + output_dict = model(input_dicts[i], input_time) + + # equation forward + for name, expr in expr_dict.items(): + if name not in label_dicts[i]: + continue + if callable(expr): + output_dict[name] = expr({**output_dict, **input_dicts[i]}) + else: + raise TypeError(f"expr type({type(expr)}) is invalid") + + # put field 'area' into output_dict + if "area" in input_dicts[i]: + output_dict["area"] = input_dicts[i]["area"] + + output_dicts.append(output_dict) + + # clear differentiation cache + # clear() + + # compute loss for each constraint according to its' own output, label and weight + constraint_losses = [] + for i, _constraint in enumerate(constraint.values()): + constraint_loss = _constraint.loss( + output_dicts[i], + label_dicts[i], + weight_dicts[i], + ) + constraint_losses.append(constraint_loss) + return constraint_losses + + @jit.to_static + def eval_forward( + self, + expr_dict: Dict[str, Callable], + input_dict: Dict[str, "paddle.Tensor"], + model: nn.Layer, + validator: "validate.Validator", + label_dict: Dict[str, "paddle.Tensor"], + weight_dict: Dict[str, "paddle.Tensor"], + input_time, + ) -> Tuple[Dict[str, "paddle.Tensor"], "paddle.Tensor"]: + """Forward computation for evaluation, including model forward and equation + forward. + + Args: + expr_dict (Dict[str, Callable]): Expression dict. + input_dict (Dict[str, paddle.Tensor]): Input dict. + model (nn.Layer): NN model. + validator (validate.Validator): Validator. + label_dict (Dict[str, paddle.Tensor]): Label dict. + weight_dict (Dict[str, paddle.Tensor]): Weight dict. + + Returns: + Tuple[Dict[str, paddle.Tensor], paddle.Tensor]: Result dict and loss for + given validator. + """ + # model forward + if callable(next(iter(expr_dict.values()))): + output_dict = model(input_dict, input_time) + + # equation forward + for name, expr in expr_dict.items(): + if name not in label_dict: + continue + if callable(expr): + output_dict[name] = expr({**output_dict, **input_dict}) + else: + raise TypeError(f"expr type({type(expr)}) is invalid") + + # clear differentiation cache + # clear() + + # compute loss for each validator according to its' own output, label and weight + validator_loss = validator.loss( + output_dict, + label_dict, + weight_dict, + ) + return output_dict, validator_loss + + def visu_forward( + self, + expr_dict: Optional[Dict[str, Callable]], + input_dict: Dict[str, "paddle.Tensor"], + model: nn.Layer, + ) -> Dict[str, "paddle.Tensor"]: + """Forward computation for visualization, including model forward and equation + forward. + + Args: + expr_dict (Optional[Dict[str, Callable]]): Expression dict. + input_dict (Dict[str, paddle.Tensor]): Input dict. + model (nn.Layer): NN model. + + Returns: + Dict[str, paddle.Tensor]: Result dict for given expression dict. + """ + # model forward + output_dict = model(input_dict) + + if isinstance(expr_dict, dict): + # equation forward + for name, expr in expr_dict.items(): + if callable(expr): + output_dict[name] = expr({**output_dict, **input_dict}) + else: + raise TypeError(f"expr type({type(expr)}) is invalid") + + # clear differentiation cache + # clear() + + return output_dict diff --git a/jointContribution/yinglong/ppsci/utils/initializer.py b/jointContribution/yinglong/ppsci/utils/initializer.py new file mode 100644 index 0000000000..1054b4e340 --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/initializer.py @@ -0,0 +1,452 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +The initialization method under this module is aligned with pytorch initialization. +If you need to use the initialization method of PaddlePaddle, please refer to +[paddle.nn.initializer](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/nn/initializer) + +This code is based on [torch.nn.init](https://github.com/pytorch/pytorch/blob/main/torch/nn/init.py) +Ths copyright of pytorch/pytorch is a BSD-style license, as found in the LICENSE file. +""" + +import math + +import numpy as np +import paddle +from paddle import nn +from typing_extensions import Literal + +from ppsci.utils import logger + +__all__ = [ + "uniform_", + "normal_", + "trunc_normal_", + "constant_", + "ones_", + "zeros_", + "xavier_uniform_", + "xavier_normal_", + "kaiming_uniform_", + "kaiming_normal_", + "linear_init_", + "conv_init_", +] + + +def _no_grad_uniform_(tensor, a, b): + with paddle.no_grad(): + tensor.set_value( + paddle.uniform(shape=tensor.shape, dtype=tensor.dtype, min=a, max=b) + ) + return tensor + + +def _no_grad_normal_(tensor, mean=0.0, std=1.0): + with paddle.no_grad(): + tensor.set_value(paddle.normal(mean=mean, std=std, shape=tensor.shape)) + return tensor + + +def _no_grad_trunc_normal_(tensor, mean=0.0, std=1.0, a=2.0, b=2.0): + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + logger.warning( + f"mean({mean}) is more than 2 std({std}) from [a, b]([{a}, {b}]) in _no_grad_trunc_normal_. " + "The distribution of values may be incorrect." + ) + with paddle.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + _tensor = paddle.uniform( + shape=tensor.shape, dtype=tensor.dtype, min=2 * l - 1, max=2 * u - 1 + ) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + _tensor.erfinv_() + + # Transform to proper mean, std + _tensor = paddle.multiply(_tensor, paddle.to_tensor(std * math.sqrt(2.0))) + _tensor = paddle.add(_tensor, paddle.to_tensor(mean)) + + # Clamp to ensure it"s in the proper range + _tensor = paddle.clip(_tensor, min=a, max=b) + tensor.set_value(_tensor) + return tensor + + +def _no_grad_fill_(tensor, value=0.0): + with paddle.no_grad(): + tensor.set_value(paddle.full_like(tensor, value, dtype=tensor.dtype)) + return tensor + + +def uniform_(tensor: paddle.Tensor, a: float, b: float) -> paddle.Tensor: + """Modify tensor inplace using uniform_. + + Args: + tensor (paddle.Tensor): Paddle Tensor. + a (float): min value. + b (float): max value. + + Returns: + paddle.Tensor: Initialized tensor. + + Examples: + >>> import paddle + >>> import ppsci + >>> param = paddle.empty((128, 256), "float32") + >>> param = ppsci.utils.initializer.uniform_(param, -1, 1) + """ + return _no_grad_uniform_(tensor, a, b) + + +def normal_( + tensor: paddle.Tensor, mean: float = 0.0, std: float = 1.0 +) -> paddle.Tensor: + """Modify tensor inplace using normal_. + + Args: + tensor (paddle.Tensor): Paddle Tensor. + mean (float, optional): mean value. Defaults to 0.0. + std (float, optional): std value. Defaults to 1.0. + + Returns: + paddle.Tensor: Initialized tensor. + + Examples: + >>> import paddle + >>> import ppsci + >>> param = paddle.empty((128, 256), "float32") + >>> param = ppsci.utils.initializer.normal_(param, 0, 1) + """ + return _no_grad_normal_(tensor, mean, std) + + +def trunc_normal_( + tensor: paddle.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> paddle.Tensor: + """Modify tensor inplace using trunc_normal_. + + Args: + tensor (paddle.Tensor): Paddle Tensor. + mean (float, optional): The mean of the normal distribution. Defaults to 0.0. + std (float, optional): The standard deviation of the normal distribution. Defaults to 1.0. + a (float, optional): The minimum cutoff value. Defaults to -2.0. + b (float, optional): The maximum cutoff value. Defaults to 2.0. + + Returns: + paddle.Tensor: Initialized tensor. + + Examples: + >>> import paddle + >>> import ppsci + >>> param = paddle.empty((128, 256), "float32") + >>> param = ppsci.utils.initializer.trunc_normal_(param, 0.0, 1.0) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def constant_(tensor: paddle.Tensor, value: float = 0.0) -> paddle.Tensor: + """Modify tensor inplace using constant_. + + Args: + tensor (paddle.Tensor): Paddle Tensor. + value (float, optional): value to fill tensor. Defaults to 0.0. + + Returns: + paddle.Tensor: Initialized tensor. + + Examples: + >>> import paddle + >>> import ppsci + >>> param = paddle.empty((128, 256), "float32") + >>> param = ppsci.utils.initializer.constant_(param, 2) + """ + return _no_grad_fill_(tensor, value) + + +def ones_(tensor: paddle.Tensor) -> paddle.Tensor: + """Modify tensor inplace using ones_. + + Args: + tensor (paddle.Tensor): Paddle Tensor. + + Returns: + paddle.Tensor: Initialized tensor. + + Examples: + >>> import paddle + >>> import ppsci + >>> param = paddle.empty((128, 256), "float32") + >>> param = ppsci.utils.initializer.ones_(param) + """ + return _no_grad_fill_(tensor, 1) + + +def zeros_(tensor: paddle.Tensor) -> paddle.Tensor: + """Modify tensor inplace using zeros_. + + Args: + tensor (paddle.Tensor): Paddle Tensor. + + Returns: + paddle.Tensor: Initialized tensor. + + Examples: + >>> import paddle + >>> import ppsci + >>> param = paddle.empty((128, 256), "float32") + >>> param = ppsci.utils.initializer.zeros_(param) + """ + return _no_grad_fill_(tensor, 0) + + +def _calculate_fan_in_and_fan_out(tensor, reverse=False): + """ + Calculate (fan_in, _fan_out) for tensor. + + Args: + tensor (paddle.Tensor): paddle.Tensor. + reverse (bool): tensor data format order, False by default as [fout, fin, ...]. + e.g. : conv.weight [cout, cin, kh, kw] is False; linear.weight [cin, cout] + is True. + + Return: + Tuple[float, float]: (fan_in, fan_out). + """ + if tensor.ndim < 2: + raise ValueError( + f"tensor.ndim should be no less than 2, but got {tensor.ndim}." + ) + + if reverse: + num_input_fmaps, num_output_fmaps = tensor.shape[0], tensor.shape[1] + else: + num_input_fmaps, num_output_fmaps = tensor.shape[1], tensor.shape[0] + + receptive_field_size = 1 + if tensor.ndim > 2: + receptive_field_size = np.prod(tensor.shape[2:]) + + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def xavier_uniform_( + tensor: paddle.Tensor, gain: float = 1.0, reverse: bool = False +) -> paddle.Tensor: + """Modify tensor inplace using xavier_uniform_. + + Args: + tensor (paddle.Tensor): Paddle Tensor. + gain (float, optional): Hyperparameter. Defaults to 1.0. + reverse (bool, optional): Tensor data format order, False by default as + [fout, fin, ...].. Defaults to False. + + Returns: + paddle.Tensor: Initialized tensor. + + Examples: + >>> import paddle + >>> import ppsci + >>> param = paddle.empty((128, 256), "float32") + >>> param = ppsci.utils.initializer.xavier_uniform_(param) + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + k = math.sqrt(3.0) * std + return _no_grad_uniform_(tensor, -k, k) + + +def xavier_normal_( + tensor: paddle.Tensor, gain: float = 1.0, reverse: bool = False +) -> paddle.Tensor: + """Modify tensor inplace using xavier_normal_. + + Args: + tensor (paddle.Tensor): Paddle Tensor. + gain (float, optional): Hyperparameter. Defaults to 1.0. + reverse (bool, optional): tensor data format order, False by + default as [fout, fin, ...]. Defaults to False. + + Returns: + paddle.Tensor: Initialized tensor. + + Examples: + >>> import paddle + >>> import ppsci + >>> param = paddle.empty((128, 256), "float32") + >>> param = ppsci.utils.initializer.xavier_normal_(param) + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + return _no_grad_normal_(tensor, 0, std) + + +# reference: https://pytorch.org/docs/stable/_modules/torch/nn/init.html +def _calculate_correct_fan(tensor, mode, reverse=False): + mode = mode.lower() + valid_modes = ["fan_in", "fan_out"] + if mode not in valid_modes: + raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") + + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse) + + return fan_in if mode == "fan_in" else fan_out + + +def _calculate_gain(nonlinearity, param=None): + linear_fns = [ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + ] + if nonlinearity in linear_fns or nonlinearity == "sigmoid": + return 1 + elif nonlinearity == "tanh": + return 5.0 / 3 + elif nonlinearity == "relu": + return math.sqrt(2.0) + elif nonlinearity == "leaky_relu": + if param is None: + negative_slope = 0.01 + elif ( + not isinstance(param, bool) + and isinstance(param, int) + or isinstance(param, float) + ): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError(f"negative_slope {param} not a valid number") + return math.sqrt(2.0 / (1 + negative_slope**2)) + elif nonlinearity == "selu": + return 3.0 / 4 + else: + raise ValueError(f"Unsupported nonlinearity {nonlinearity}") + + +def kaiming_uniform_( + tensor: paddle.Tensor, + a: float = 0, + mode: Literal["fan_in", "fan_out"] = "fan_in", + nonlinearity: str = "leaky_relu", + reverse: bool = False, +) -> paddle.Tensor: + """Modify tensor inplace using kaiming_uniform method. + + Args: + tensor (paddle.Tensor): Paddle Tensor. + a (float, optional): The negative slope of the rectifier used after this layer. + Defaults to 0. + mode (Literal["fan_in", "fan_out"], optional): + ["fan_in", "fan_out"]. Defaults to "fan_in". + nonlinearity (str, optional): Nonlinearity method name. Defaults to "leaky_relu". + reverse (bool, optional): tensor data format order, False by default as + [fout, fin, ...].. Defaults to False. + + Returns: + paddle.Tensor: Initialized tensor. + + Examples: + >>> import paddle + >>> import ppsci + >>> param = paddle.empty((128, 256), "float32") + >>> param = ppsci.utils.initializer.kaiming_uniform_(param) + """ + fan = _calculate_correct_fan(tensor, mode, reverse) + gain = _calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + k = math.sqrt(3.0) * std + return _no_grad_uniform_(tensor, -k, k) + + +def kaiming_normal_( + tensor: paddle.Tensor, + a: float = 0, + mode: Literal["fan_in", "fan_out"] = "fan_in", + nonlinearity: str = "leaky_relu", + reverse: bool = False, +) -> paddle.Tensor: + """Modify tensor inplace using kaiming_normal_. + + Args: + tensor (paddle.Tensor): Paddle Tensor. + a (float, optional): The negative slope of the rectifier used after this layer. + Defaults to 0. + mode (Literal["fan_in", "fan_out"], optional): Either + 'fan_in' (default) or 'fan_out'. Defaults to "fan_in". + nonlinearity (str, optional): Nonlinearity method name. Defaults to "leaky_relu". + reverse (bool, optional): Tensor data format order. Defaults to False. + + Returns: + paddle.Tensor: Initialized tensor. + + Examples: + >>> import paddle + >>> import ppsci + >>> param = paddle.empty((128, 256), "float32") + >>> param = ppsci.utils.initializer.kaiming_normal_(param) + """ + fan = _calculate_correct_fan(tensor, mode, reverse) + gain = _calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + return _no_grad_normal_(tensor, 0, std) + + +def linear_init_(module: nn.Layer) -> None: + """Initialize module's weight and bias as it is a linear layer. + + Args: + module (nn.Layer): Linear Layer to be initialized. + """ + bound = 1 / math.sqrt(module.weight.shape[0]) + uniform_(module.weight, -bound, bound) + if module.bias is not None: + uniform_(module.bias, -bound, bound) + + +def conv_init_(module: nn.Layer) -> None: + """Initialize module's weight and bias as it is a conv layer. + + Args: + module (nn.Layer): Convolution Layer to be initialized. + """ + bound = 1 / np.sqrt(np.prod(module.weight.shape[1:])) + uniform_(module.weight, -bound, bound) + if module.bias is not None: + uniform_(module.bias, -bound, bound) diff --git a/jointContribution/yinglong/ppsci/utils/logger.py b/jointContribution/yinglong/ppsci/utils/logger.py new file mode 100644 index 0000000000..0f58c3c72a --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/logger.py @@ -0,0 +1,177 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys + +import paddle.distributed as dist + +_logger = None + + +def init_logger(name="ppsci", log_file=None, log_level=logging.INFO): + """Initialize and get a logger by name. + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. If `log_file` is specified a FileHandler will also be added. + + Args: + name (str): Logger name. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the logger. + log_level (int): The logger level. Note that only the process of + rank 0 is affected, and other processes will set the level to + "Error" thus be silent most of the time. + Returns: + logging.Logger: The expected logger. + """ + if isinstance(log_level, str): + log_level = getattr(logging, log_level.upper()) + + global _logger + + # solve mutiple init issue when using paddlescience.py and engin.engin + init_flag = False + if _logger is None: + _logger = logging.getLogger(name) + init_flag = True + + formatter = logging.Formatter( + "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S" + ) + + stream_handler = logging.StreamHandler(stream=sys.stdout) + stream_handler.setFormatter(formatter) + stream_handler._name = "stream_handler" + + # add stream_handler when _logger dose not contain stream_handler + for i, h in enumerate(_logger.handlers): + if h.get_name() == stream_handler.get_name(): + break + if i == len(_logger.handlers) - 1: + _logger.addHandler(stream_handler) + if init_flag: + _logger.addHandler(stream_handler) + + if log_file is not None and dist.get_rank() == 0: + log_file_folder = os.path.split(log_file)[0] + os.makedirs(log_file_folder, exist_ok=True) + file_handler = logging.FileHandler(log_file, "a") + file_handler.setFormatter(formatter) + file_handler._name = "file_handler" + + # add file_handler when _logger dose not contain same file_handler + for i, h in enumerate(_logger.handlers): + if ( + h.get_name() == file_handler.get_name() + and h.baseFilename == file_handler.baseFilename + ): + break + if i == len(_logger.handlers) - 1: + _logger.addHandler(file_handler) + + if dist.get_rank() == 0: + _logger.setLevel(log_level) + else: + _logger.setLevel(logging.ERROR) + _logger.propagate = False + + +def set_log_level(log_level): + """Set log level.""" + if dist.get_rank() == 0: + _logger.setLevel(log_level) + else: + _logger.setLevel(logging.ERROR) + + +def log_at_trainer0(log): + """ + Logs will print multi-times when calling Fleet API. + Only display single log and ignore the others. + """ + + def wrapper(fmt, *args): + if dist.get_rank() == 0: + log(fmt, *args) + + return wrapper + + +@log_at_trainer0 +def info(fmt, *args): + _logger.info(fmt, *args) + + +@log_at_trainer0 +def debug(fmt, *args): + _logger.debug(fmt, *args) + + +@log_at_trainer0 +def warning(fmt, *args): + _logger.warning(fmt, *args) + + +@log_at_trainer0 +def error(fmt, *args): + _logger.error(fmt, *args) + + +def scaler(name, value, step, writer): + """ + This function will draw a scalar curve generated by the visualdl. + Usage: Install visualdl: pip3 install visualdl==2.0.0b4 + and then: + visualdl --logdir ./scalar --host 0.0.0.0 --port 8830 + to preview loss corve in real time. + """ + if writer is None: + return + writer.add_scalar(tag=name, step=step, value=value) + + +def advertise(): + """ + Show the advertising message like the following: + + =========================================================== + == PaddleScience is powered by PaddlePaddle ! == + =========================================================== + == == + == For more info please go to the following website. == + == == + == https://github.com/PaddlePaddle/PaddleScience == + =========================================================== + + """ + _copyright = "PaddleScience is powered by PaddlePaddle !" + ad = "For more info please go to the following website." + website = "https://github.com/PaddlePaddle/PaddleScience" + AD_LEN = 6 + len(max([_copyright, ad, website], key=len)) + + info( + "\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format( + "=" * (AD_LEN + 4), + "=={}==".format(_copyright.center(AD_LEN)), + "=" * (AD_LEN + 4), + "=={}==".format(" " * AD_LEN), + "=={}==".format(ad.center(AD_LEN)), + "=={}==".format(" " * AD_LEN), + "=={}==".format(website.center(AD_LEN)), + "=" * (AD_LEN + 4), + ) + ) diff --git a/jointContribution/yinglong/ppsci/utils/misc.py b/jointContribution/yinglong/ppsci/utils/misc.py new file mode 100644 index 0000000000..27a7a9ae2c --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/misc.py @@ -0,0 +1,267 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import functools +import random +from typing import Callable +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union + +import numpy as np +import paddle + +__all__ = [ + "all_gather", + "AverageMeter", + "PrettyOrderedDict", + "Prettydefaultdict", + "concat_dict_list", + "convert_to_array", + "convert_to_dict", + "stack_dict_list", + "combine_array_with_time", + "set_random_seed", + "run_on_eval_mode", +] + + +class AverageMeter: + """ + Computes and stores the average and current value + Code was based on https://github.com/pytorch/examples/blob/master/imagenet/main.py + """ + + def __init__(self, name="", fmt="f", postfix="", need_avg=True): + self.name = name + self.fmt = fmt + self.postfix = postfix + self.need_avg = need_avg + self.reset() + + def reset(self): + """Reset""" + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + """Update""" + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + @property + def avg_info(self): + if isinstance(self.avg, paddle.Tensor): + self.avg = float(self.avg) + return f"{self.name}: {self.avg:.5f}" + + @property + def total(self): + return f"{self.name}_sum: {self.sum:{self.fmt}}{self.postfix}" + + @property + def total_minute(self): + return f"{self.name} {self.sum / 60:{self.fmt}}{self.postfix} min" + + @property + def mean(self): + return ( + f"{self.name}: {self.avg:{self.fmt}}{self.postfix}" if self.need_avg else "" + ) + + @property + def value(self): + return f"{self.name}: {self.val:{self.fmt}}{self.postfix}" + + +class PrettyOrderedDict(collections.OrderedDict): + def __str__(self): + return "".join([str((k, v)) for k, v in self.items()]) + + +class Prettydefaultdict(collections.defaultdict): + def __str__(self): + return "".join([str((k, v)) for k, v in self.items()]) + + +def convert_to_dict(array: np.ndarray, keys: Tuple[str, ...]) -> Dict[str, np.ndarray]: + """Split given array into single channel array at axis -1 in order of given keys. + + Args: + array (np.ndarray): Array to be splited. + keys (Tuple[str, ...]):Keys used in split. + + Returns: + Dict[str, np.ndarray]: Splited dict. + """ + if array.shape[-1] != len(keys): + raise ValueError( + f"dim of array({array.shape[-1]}) must equal to " f"len(keys)({len(keys)})" + ) + + split_array = np.split(array, len(keys), axis=-1) + return {key: split_array[i] for i, key in enumerate(keys)} + + +def all_gather( + tensor: paddle.Tensor, concat: bool = True, axis: int = 0 +) -> Union[paddle.Tensor, List[paddle.Tensor]]: + """Gather tensor from all devices, concatenate them along given axis if specified. + + Args: + tensor (paddle.Tensor): Tensor to be gathered from all GPUs. + concat (bool, optional): Whether to concatenate gathered Tensors. Defaults to True. + axis (int, optional): Axis which concatenated along. Defaults to 0. + + Returns: + Union[paddle.Tensor, List[paddle.Tensor]]: Gathered Tensors + """ + result = [] + paddle.distributed.all_gather(result, tensor) + if concat: + return paddle.concat(result, axis) + return result + + +def convert_to_array(dict_: Dict[str, np.ndarray], keys: Tuple[str, ...]) -> np.ndarray: + """Concatenate arrays in axis -1 in order of given keys. + + Args: + dict_ (Dict[str, np.ndarray]): Dict contains arrays. + keys (Tuple[str, ...]): Concatenate keys used in concatenation. + + Returns: + np.ndarray: Concatenated array. + """ + return np.concatenate([dict_[key] for key in keys], axis=-1) + + +def concat_dict_list( + dict_list: Tuple[Dict[str, np.ndarray], ...] +) -> Dict[str, np.ndarray]: + """Concatenate arrays in tuple of dicts at axis 0. + + Args: + dict_list (Tuple[Dict[str, np.ndarray], ...]): Tuple of dicts. + + Returns: + Dict[str, np.ndarray]: A dict with concatenated arrays for each key. + """ + ret = {} + for key in dict_list[0].keys(): + ret[key] = np.concatenate([_dict[key] for _dict in dict_list], axis=0) + return ret + + +def stack_dict_list( + dict_list: Tuple[Dict[str, np.ndarray], ...] +) -> Dict[str, np.ndarray]: + """Stack arrays in tuple of dicts at axis 0. + + Args: + dict_list (Tuple[Dict[str, np.ndarray], ...]): Tuple of dicts. + + Returns: + Dict[str, np.ndarray]: A dict with stacked arrays for each key. + """ + ret = {} + for key in dict_list[0].keys(): + ret[key] = np.stack([_dict[key] for _dict in dict_list], axis=0) + return ret + + +def typename(obj: object) -> str: + """Return type name of given object. + + Args: + obj (object): Python object which is instantiated from a class. + + Returns: + str: Class name of given object. + """ + return obj.__class__.__name__ + + +def combine_array_with_time(x: np.ndarray, t: Tuple[int, ...]) -> np.ndarray: + """Combine given data x with time sequence t. + Given x with shape (N, D) and t with shape (T, ), + this function will repeat t_i for N times and will concat it with data x for each t_i in t, + finally return the stacked result, whic is of shape (NxT, D+1). + + Args: + x (np.ndarray): Points data with shape (N, D). + t (Tuple[int, ...]): Time sequence with shape (T, ). + + Returns: + np.ndarray: Combined data with shape of (NxT, D+1). + """ + nx = len(x) + tx = [] + for ti in t: + tx.append( + np.hstack( + (np.full([nx, 1], float(ti), dtype=paddle.get_default_dtype()), x) + ) + ) + tx = np.vstack(tx) + return tx + + +def set_random_seed(seed: int): + """Set numpy, random, paddle random_seed to given seed. + + Args: + seed (int): Random seed. + """ + paddle.seed(seed) + np.random.seed(seed) + random.seed(seed) + + +def run_on_eval_mode(func: Callable) -> Callable: + """A decorator automatically running given class method in eval mode and keep + training state unchanged after function finished. + + Args: + func (Callable): Class method which is expected running in eval mode. + + Returns: + Callable: Decorated class method. + """ + + @functools.wraps(func) + def function_with_eval_state(self, *args, **kwargs): + # log original state + train_state = self.model.training + + # switch to eval mode + if train_state: + self.model.eval() + + # run func in eval mode + result = func(self, *args, **kwargs) + + # restore state + if train_state: + self.model.train() + + return result + + return function_with_eval_state diff --git a/jointContribution/yinglong/ppsci/utils/profiler.py b/jointContribution/yinglong/ppsci/utils/profiler.py new file mode 100644 index 0000000000..4eeb75d863 --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/profiler.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import paddle + +# A global variable to record the number of calling times for profiler +# functions. It is used to specify the tracing range of training steps. +_profiler_step_id = 0 + +# A global variable to avoid parsing from string every time. +_profiler_options = None + + +class ProfilerOptions: + """ + Use a string to initialize a ProfilerOptions. + The string should be in the format: "key1=value1;key2=value;key3=value3". + For example: + "profile_path=model.profile" + "batch_range=[50, 60]; profile_path=model.profile" + "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile" + + ProfilerOptions supports following key-value pair: + batch_range - a integer list, e.g. [100, 110]. + state - a string, the optional values are "CPU", "GPU" or "All". + sorted_key - a string, the optional values are "calls", "total", + "max", "min" or "ave. + tracer_option - a string, the optional values are "Default", "OpDetail", + "AllOpDetail". + profile_path - a string, the path to save the serialized profile data, + which can be used to generate a timeline. + exit_on_finished - a boolean. + """ + + def __init__(self, options_str): + if not isinstance(options_str, str): + raise ValueError() + + self._options = { + "batch_range": [10, 20], + "state": "All", + "sorted_key": "total", + "tracer_option": "Default", + "profile_path": "/tmp/profile", + "exit_on_finished": True, + } + self._parse_from_string(options_str) + + def _parse_from_string(self, options_str): + for kv in options_str.replace(" ", "").split(";"): + key, value = kv.split("=") + if key == "batch_range": + value_list = value.replace("[", "").replace("]", "").split(",") + value_list = list(map(int, value_list)) + if ( + len(value_list) >= 2 + and value_list[0] >= 0 + and value_list[1] > value_list[0] + ): + self._options[key] = value_list + elif key == "exit_on_finished": + self._options[key] = value.lower() in ("yes", "true", "t", "1") + elif key in ["state", "sorted_key", "tracer_option", "profile_path"]: + self._options[key] = value + + def __getitem__(self, name): + if self._options.get(name, None) is None: + raise ValueError(f"ProfilerOptions does not have an option named {name}.") + return self._options[name] + + +def add_profiler_step(options_str=None): + """ + Enable the operator-level timing using PaddlePaddle"s profiler. + The profiler uses a independent variable to count the profiler steps. + One call of this function is treated as a profiler step. + + Args: + profiler_options - a string to initialize the ProfilerOptions. + Default is None, and the profiler is disabled. + """ + if options_str is None: + return + + global _profiler_step_id + global _profiler_options + + if _profiler_options is None: + _profiler_options = ProfilerOptions(options_str) + + if _profiler_step_id == _profiler_options["batch_range"][0]: + paddle.utils.profiler.start_profiler( + _profiler_options["state"], _profiler_options["tracer_option"] + ) + elif _profiler_step_id == _profiler_options["batch_range"][1]: + paddle.utils.profiler.stop_profiler( + _profiler_options["sorted_key"], _profiler_options["profile_path"] + ) + if _profiler_options["exit_on_finished"]: + sys.exit(0) + + _profiler_step_id += 1 diff --git a/jointContribution/yinglong/ppsci/utils/reader.py b/jointContribution/yinglong/ppsci/utils/reader.py new file mode 100644 index 0000000000..48300ac96b --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/reader.py @@ -0,0 +1,182 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import csv +import sys +from typing import Dict +from typing import Optional +from typing import Tuple + +import meshio +import numpy as np +import paddle +import scipy.io as sio + +from ppsci.utils import logger + +__all__ = [ + "load_csv_file", + "load_mat_file", + "load_vtk_file", + "load_vtk_with_time_file", +] + + +def load_csv_file( + file_path: str, + keys: Tuple[str, ...], + alias_dict: Optional[Dict[str, str]] = None, + delimeter: str = ",", + encoding: str = "utf-8", +) -> Dict[str, np.ndarray]: + """Load *.csv file and fetch data as given keys. + + Args: + file_path (str): CSV file path. + keys (Tuple[str, ...]): Required fetching keys. + alias_dict (Optional[Dict[str, str]]): Alias for keys, + i.e. {inner_key: outer_key}. Defaults to None. + encoding (str, optional): Encoding code when open file. Defaults to "utf-8". + + Returns: + Dict[str, np.ndarray]: Loaded data in dict. + """ + if alias_dict is None: + alias_dict = {} + + try: + # read all data from csv file + with open(file_path, "r", encoding=encoding) as csv_file: + reader = csv.DictReader(csv_file, delimiter=delimeter) + raw_data = collections.defaultdict(list) + for _, line_dict in enumerate(reader): + for key, value in line_dict.items(): + raw_data[key].append(value) + except FileNotFoundError: + logger.error(f"{file_path} isn't a valid csv file.") + sys.exit() + + # convert to numpy array + data_dict = {} + for key in keys: + fetch_key = alias_dict[key] if key in alias_dict else key + if fetch_key not in raw_data: + raise KeyError(f"fetch_key({fetch_key}) do not exist in raw_data.") + data_dict[key] = np.asarray( + raw_data[fetch_key], paddle.get_default_dtype() + ).reshape([-1, 1]) + + return data_dict + + +def load_mat_file( + file_path: str, keys: Tuple[str, ...], alias_dict: Optional[Dict[str, str]] = None +) -> Dict[str, np.ndarray]: + """Load *.mat file and fetch data as given keys. + + Args: + file_path (str): Mat file path. + keys (Tuple[str, ...]): Required fetching keys. + alias_dict (Optional[Dict[str, str]]): Alias for keys, + i.e. {original_key: original_key}. Defaults to None. + + Returns: + Dict[str, np.ndarray]: Loaded data in dict. + """ + + if alias_dict is None: + alias_dict = {} + + try: + # read all data from mat file + raw_data = sio.loadmat(file_path) + except FileNotFoundError: + logger.error(f"{file_path} isn't a valid mat file.") + raise + + # convert to numpy array + data_dict = {} + for key in keys: + fetch_key = alias_dict[key] if key in alias_dict else key + if fetch_key not in raw_data: + raise KeyError(f"fetch_key({fetch_key}) do not exist in raw_data.") + data_dict[key] = np.asarray( + raw_data[fetch_key], paddle.get_default_dtype() + ).reshape([-1, 1]) + + return data_dict + + +def load_vtk_file( + filename_without_timeid: str, + time_step: float, + time_index: Tuple[int, ...], + input_keys: Tuple[str, ...], + label_keys: Optional[Tuple[str, ...]], +) -> Dict[str, np.ndarray]: + """Load coordinates and attached label from the *.vtu file. + + Args: + filename_without_timeid (str): File name without time id. + time_step (float): Physical time step. + time_index (Tuple[int, ...]): Physical time indexes. + input_keys (Tuple[str, ...]): Input coordinates name keys. + label_keys (Optional[Tuple[str, ...]]): Input label name keys. + + Returns: + Dict[str, np.ndarray]: Input coordinates dict, label coordinates dict + """ + input_dict = {var: [] for var in input_keys} + label_dict = {var: [] for var in label_keys} + for index in time_index: + file = filename_without_timeid + f"{index}.vtu" + mesh = meshio.read(file) + n = mesh.points.shape[0] + i = 0 + for key in input_dict: + if key == "t": + input_dict[key].append(np.full((n, 1), index * time_step, "float32")) + else: + input_dict[key].append( + mesh.points[:, i].reshape(n, 1).astype("float32") + ) + i += 1 + for i, key in enumerate(label_dict): + label_dict[key].append(np.array(mesh.point_data[key], "float32")) + for key in input_dict: + input_dict[key] = np.concatenate(input_dict[key]) + for key in label_dict: + label_dict[key] = np.concatenate(label_dict[key]) + + return input_dict, label_dict + + +def load_vtk_with_time_file(file: str) -> Dict[str, np.ndarray]: + """Temporary interface for points cloud, will be banished sooner. + + Args: + file (str): input file name. + + Returns: + Dict[str, np.ndarray]: Input coordinates dict. + """ + mesh = meshio.read(file) + n = mesh.points.shape[0] + t = np.array(mesh.point_data["time"]) + x = mesh.points[:, 0].reshape(n, 1) + y = mesh.points[:, 1].reshape(n, 1) + z = mesh.points[:, 2].reshape(n, 1) + input_dict = {"t": t, "x": x, "y": y, "z": z} + return input_dict diff --git a/jointContribution/yinglong/ppsci/utils/save_load.py b/jointContribution/yinglong/ppsci/utils/save_load.py new file mode 100644 index 0000000000..d94ed4d2d1 --- /dev/null +++ b/jointContribution/yinglong/ppsci/utils/save_load.py @@ -0,0 +1,150 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any +from typing import Dict + +import paddle + +from ppsci.utils import download +from ppsci.utils import logger + +__all__ = ["load_checkpoint", "save_checkpoint", "load_pretrain"] + + +def _load_pretrain_from_path(model, path, equation=None): + """Load pretrained model from given path. + + Args: + model (nn.Layer): Model with parameters. + path (str, optional): Pretrained model path. + equation (Optional[Dict[str, ppsci.equation.PDE]]): Equations. Defaults to None. + """ + if not (os.path.isdir(path) or os.path.exists(f"{path}.pdparams")): + raise FileNotFoundError( + f"Pretrained model path {path}.pdparams does not exists." + ) + + param_state_dict = paddle.load(f"{path}.pdparams") + model.set_dict(param_state_dict) + if equation is not None: + if not os.path.exists(f"{path}.pdeqn"): + logger.warning(f"{path}.pdeqn not found.") + else: + equation_dict = paddle.load(f"{path}.pdeqn") + for name, _equation in equation.items(): + _equation.set_state_dict(equation_dict[name]) + + logger.info(f"Finish loading pretrained model from {path}") + + +def load_pretrain(model, path, equation=None): + """Load pretrained model from given path or url. + + Args: + model (nn.Layer): Model with parameters. + path (str): Pretrained model url. + equation (Optional[Dict[str, ppsci.equation.PDE]]): Equations. Defaults to None. + """ + if path.startswith("http"): + path = download.get_weights_path_from_url(path).replace(".pdparams", "") + _load_pretrain_from_path(model, path, equation) + + +def load_checkpoint( + path, model, optimizer, grad_scaler=None, equation=None +) -> Dict[str, Any]: + """Load from checkpoint. + + Args: + path (AttrDict): Path for checkpoint. + model (nn.Layer): Model with parameters. + optimizer (optimizer.Optimizer, optional): Optimizer for model. + grad_scaler (Optional[amp.GradScaler]): GradScaler for AMP. Defaults to None. + equation (Optional[Dict[str, ppsci.equation.PDE]]): Equations. Defaults to None. + + Returns: + Dict[str, Any]: Loaded metric information. + """ + if not os.path.exists(f"{path}.pdparams"): + raise FileNotFoundError(f"{path}.pdparams not exist.") + if not os.path.exists(f"{path}.pdopt"): + raise FileNotFoundError(f"{path}.pdopt not exist.") + if grad_scaler is not None and not os.path.exists(f"{path}.pdscaler"): + raise FileNotFoundError(f"{path}.scaler not exist.") + + # load state dict + param_dict = paddle.load(f"{path}.pdparams") + optim_dict = paddle.load(f"{path}.pdopt") + metric_dict = paddle.load(f"{path}.pdstates") + if grad_scaler is not None: + scaler_dict = paddle.load(f"{path}.pdscaler") + if equation is not None: + if not os.path.exists(f"{path}.pdeqn"): + logger.warning(f"{path}.pdeqn not found.") + equation_dict = None + else: + equation_dict = paddle.load(f"{path}.pdeqn") + + # set state dict + model.set_state_dict(param_dict) + optimizer.set_state_dict(optim_dict) + if grad_scaler is not None: + grad_scaler.load_state_dict(scaler_dict) + if equation is not None and equation_dict is not None: + for name, _equation in equation.items(): + _equation.set_state_dict(equation_dict[name]) + + logger.info(f"Finish loading checkpoint from {path}") + return metric_dict + + +def save_checkpoint( + model, optimizer, grad_scaler, metric, model_dir, prefix="model", equation=None +): + """Save checkpoint, including model params, optimizer params, metric information. + + Args: + model (nn.Layer): Model with parameters. + optimizer (optimizer.Optimizer): Optimizer for model. + grad_scaler (Optional[amp.GradScaler]): GradScaler for AMP. Defaults to None. + metric (Dict[str, float]): Metric information, such as {"RMSE": ...}. + model_dir (str): Directory for chekpoint storage. + prefix (str, optional): Prefix for storage. Defaults to "ppsci". + equation (Optional[Dict[str, ppsci.equation.PDE]]): Equations. Defaults to None. + """ + if paddle.distributed.get_rank() != 0: + return + if model_dir is None: + logger.warning( + f"model_dir({model_dir}) is set to None, skip save_checkpoint..." + ) + return + model_dir = os.path.join(model_dir, "checkpoints") + os.makedirs(model_dir, exist_ok=True) + model_path = os.path.join(model_dir, prefix) + + paddle.save(model.state_dict(), f"{model_path}.pdparams") + paddle.save(optimizer.state_dict(), f"{model_path}.pdopt") + paddle.save(metric, f"{model_path}.pdstates") + if grad_scaler is not None: + paddle.save(grad_scaler.state_dict(), f"{model_path}.pdscaler") + if equation is not None: + paddle.save( + {key: eq.state_dict() for key, eq in equation.items()}, + f"{model_path}.pdeqn", + ) + + logger.info(f"Finish saving checkpoint to {model_path}") diff --git a/jointContribution/yinglong/ppsci/validate/__init__.py b/jointContribution/yinglong/ppsci/validate/__init__.py new file mode 100644 index 0000000000..3d43c893c0 --- /dev/null +++ b/jointContribution/yinglong/ppsci/validate/__init__.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from ppsci.loss import build_loss +from ppsci.metric import build_metric +from ppsci.utils import logger +from ppsci.utils import misc +from ppsci.validate.base import Validator + +# from ppsci.validate.geo_validator import GeometryValidator +from ppsci.validate.sup_validator import SupervisedValidator + +__all__ = [ + "Validator", + # "GeometryValidator", + "SupervisedValidator", +] + + +def build_validator(cfg, equation_dict, geom_dict): + """Build validator(s). + + Args: + cfg (List[AttrDict]): Validator(s) config list. + geom_dict (Dct[str, Geometry]): Geometry(ies) in dict. + equation_dict (Dct[str, Equation]): Equation(s) in dict. + + Returns: + Dict[str, Validator]: Validator(s) in dict. + """ + if cfg is None: + return None + cfg = copy.deepcopy(cfg) + global_dataloader_cfg = cfg["dataloader"] + validator_cfg = cfg["content"] + + validator_dict = misc.PrettyOrderedDict() + for _item in validator_cfg: + validator_cls = next(iter(_item.keys())) + _validator_cfg = _item[validator_cls] + validator_name = _validator_cfg.get("name", validator_cls) + # select geometry + geom_name = _validator_cfg.pop("geom") + _validator_cfg["geom"] = geom_dict[geom_name] + + # update complete dataloader config + local_dataloader_cfg = _validator_cfg["dataloader"] + local_dataloader_cfg.update(global_dataloader_cfg) + + # select equation + for name, expr in _validator_cfg["output_expr"].items(): + if isinstance(expr, str) and expr in equation_dict: + _validator_cfg["output_expr"][name] = equation_dict[expr].equations[ + name + ] + + # build loss + _validator_cfg["loss"] = build_loss(_validator_cfg["loss"]) + + # build metric + _validator_cfg["metric"] = build_metric(_validator_cfg["metric"]) + + # instantiate validator + _validator_cfg["dataloader_cfg"] = _validator_cfg.pop("dataloader") + validator_dict[validator_name] = eval(validator_cls)(**_validator_cfg) + + logger.debug(str(validator_dict[validator_name])) + + return validator_dict diff --git a/jointContribution/yinglong/ppsci/validate/base.py b/jointContribution/yinglong/ppsci/validate/base.py new file mode 100644 index 0000000000..a52b273c8b --- /dev/null +++ b/jointContribution/yinglong/ppsci/validate/base.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Dict + +from paddle import io + +from ppsci import data +from ppsci import loss +from ppsci import metric + + +class Validator: + """Base class for validators. + + Args: + dataset (io.Dataset): Dataset for validator. + dataloader_cfg (Dict[str, Any]): Dataloader config. + loss (loss.Loss): Loss functor. + metric (Dict[str, metric.Metric]): Named metric functors in dict. + name (str): Name of validator. + """ + + def __init__( + self, + dataset: io.Dataset, + dataloader_cfg: Dict[str, Any], + loss: loss.Loss, + metric: Dict[str, metric.Metric], + name: str, + ): + self.data_loader = data.build_dataloader(dataset, dataloader_cfg) + self.data_iter = iter(self.data_loader) + self.loss = loss + self.metric = metric + self.name = name + + def __str__(self): + return ", ".join( + [ + self.__class__.__name__, + f"name = {self.name}", + f"input_keys = {self.input_keys}", + f"output_keys = {self.output_keys}", + f"output_expr = {self.output_expr}", + f"label_dict = {self.label_dict}", + f"len(dataloader) = {len(self.data_loader)}", + f"loss = {self.loss}", + f"metric = {list(self.metric.keys())}", + ] + ) diff --git a/jointContribution/yinglong/ppsci/validate/sup_validator.py b/jointContribution/yinglong/ppsci/validate/sup_validator.py new file mode 100644 index 0000000000..61081cd5ee --- /dev/null +++ b/jointContribution/yinglong/ppsci/validate/sup_validator.py @@ -0,0 +1,99 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional + +from ppsci import loss +from ppsci import metric +from ppsci.data import dataset +from ppsci.validate import base + + +class SupervisedValidator(base.Validator): + """Validator for supervised models. + + Args: + dataloader_cfg (Dict[str, Any]): Config of building a dataloader. + loss (loss.Loss): Loss functor. + output_expr (Optional[Dict[str, Callable]]): List of label expression. + metric (Optional[Dict[str, metric.Metric]]): Named metric functors in dict. Defaults to None. + name (Optional[str]): Name of validator. Defaults to None. + + Examples: + >>> import ppsci + >>> valida_dataloader_cfg = { + ... "dataset": { + ... "name": "MatDataset", + ... "file_path": "/path/to/file.mat", + ... "input_keys": ("t_f",), + ... "label_keys": ("eta", "f"), + ... }, + ... "batch_size": 32, + ... "sampler": { + ... "name": "BatchSampler", + ... "drop_last": False, + ... "shuffle": False, + ... }, + ... } # doctest: +SKIP + >>> eta_mse_validator = ppsci.validate.SupervisedValidator( + ... valida_dataloader_cfg, + ... ppsci.loss.MSELoss("mean"), + ... {"eta": lambda out: out["eta"]}, + ... metric={"MSE": ppsci.metric.MSE()}, + ... name="eta_mse", + ... ) # doctest: +SKIP + """ + + def __init__( + self, + dataloader_cfg: Dict[str, Any], + loss: loss.Loss, + output_expr: Optional[Dict[str, Callable]] = None, + metric: Optional[Dict[str, metric.Metric]] = None, + name: Optional[str] = None, + ): + self.output_expr = output_expr + + # build dataset + _dataset = dataset.build_dataset(dataloader_cfg["dataset"]) + + self.input_keys = _dataset.input_keys + self.output_keys = ( + list(output_expr.keys()) if output_expr is not None else _dataset.label_keys + ) + + if self.output_expr is None: + self.output_expr = { + key: lambda out, k=key: out[k] for key in self.output_keys + } + + # construct dataloader with dataset and dataloader_cfg + super().__init__(_dataset, dataloader_cfg, loss, metric, name) + + def __str__(self): + return ", ".join( + [ + self.__class__.__name__, + f"name = {self.name}", + f"input_keys = {self.input_keys}", + f"output_keys = {self.output_keys}", + f"output_expr = {self.output_expr}", + f"len(dataloader) = {len(self.data_loader)}", + f"loss = {self.loss}", + f"metric = {list(self.metric.keys())}", + ] + ) diff --git a/jointContribution/yinglong/requirements.txt b/jointContribution/yinglong/requirements.txt new file mode 100644 index 0000000000..5963cb0555 --- /dev/null +++ b/jointContribution/yinglong/requirements.txt @@ -0,0 +1,16 @@ +numpy>=1.20.0 +scipy +sympy +matplotlib +vtk +pyevtk +wget +scipy +visualdl +pyvista==0.37.0 +pyyaml +scikit-optimize +h5py +meshio==5.3.4 +tqdm +imageio diff --git a/jointContribution/yinglong/train.sh b/jointContribution/yinglong/train.sh new file mode 100644 index 0000000000..790a455cc8 --- /dev/null +++ b/jointContribution/yinglong/train.sh @@ -0,0 +1,2 @@ +export PYTHONPATH=$PWD +python -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,4' examples/fourcastnet_hrrr/train_pretrain.py