diff --git a/docs/zh/examples/meteoformer.md b/docs/zh/examples/meteoformer.md new file mode 100644 index 0000000000..7af06aebb6 --- /dev/null +++ b/docs/zh/examples/meteoformer.md @@ -0,0 +1,189 @@ +# Meteoformer + +开始训练、评估前,请下载ERA5数据集文件 + +开始评估前,请下载或训练生成预训练模型 + +=== "模型训练命令" + + ``` sh + python main.py + ``` + +=== "模型评估命令" + + ``` sh + python main.py mode=eval + ``` + +## 1. 背景简介 + +短中期气象预测主要涉及对未来几小时至几天内的天气变化进行预测。这类预测通常需要涵盖多个气象要素,如温度、湿度、风速等,这些要素对气象变化有着复杂的时空依赖关系。准确的短中期气象预测对于防灾减灾、农业生产、航空航天等领域具有重要意义。传统的气象预测模型主要依赖于物理公式和数值天气预报(NWP),但随着深度学习的快速发展,基于数据驱动的模型逐渐展现出更强的预测能力。 + +为了有效捕捉这些多维时空特征,Meteoformer应运而生。Meteoformer是一种基于Transformer架构的模型,专门针对短中期多气象要素的预测任务进行优化。该模型能够处理多个气象变量的时空依赖关系,采用自注意力机制来捕捉不同时空尺度的关联性,从而实现更准确的温度、湿度、风速等气象要素的多步预测。通过Meteoformer,气象预报可以实现更加高效和精确的多要素预测,为气象服务提供更加可靠的数据支持。 + + +## 2. 模型原理 + +本章节对 Meteoformer 的模型原理进行简单地介绍。 + +### 2.1 编码器 + +该模块使用两层Transformer,提取空间特征更新节点特征: + +``` py linenums="8" title="ppsci/arch/Meteoformer.py" +--8<-- +ppsci/arch/Meteoformer.py:233:267 +--8<-- +``` + +### 2.2 演变器 + +该模块使用两层Transformer,学习全局时间动态特性: + +``` py linenums="29" title="ppsci/arch/Meteoformer.py" +--8<-- +ppsci/arch/Meteoformer.py:269:314 +--8<-- +``` + +### 2.3 解码器 + +该模块使用两层卷积,将时空表征解码为未来多气象要素: + +``` py linenums="29" title="ppsci/arch/Meteoformer.py" +--8<-- +ppsci/arch/Meteoformer.py:317:332 +--8<-- +``` + +### 2.4 Meteoformer模型结构 + +Meteoformer模型首先使用特征嵌入层对输入信号(过去几个时间帧的气象要素)进行空间特征编码: + +``` py linenums="73" title="ppsci/arch/Meteoformer.py" +--8<-- +ppsci/arch/Meteoformer.py:405:406 +--8<-- +``` + +``` py linenums="94" title="ppsci/arch/Meteoformer.py" +--8<-- +ppsci/arch/Meteoformer.py:233:267 +--8<-- +``` + +然后模型利用演变器将学习空间特征的动态特性,预测未来几个时间帧的气象特征: + +``` py linenums="75" title="ppsci/arch/Meteoformer.py" +--8<-- +ppsci/arch/Meteoformer.py:409:411 +--8<-- +``` + +``` py linenums="96" title="ppsci/arch/Meteoformer.py" +--8<-- +ppsci/arch/Meteoformer.py:269:314 +--8<-- +``` + +最后模型将时空动态特性与初始气象底层特征结合,使用两层卷积预测未来短中期内的多气象要素值: + +``` py linenums="112" title="ppsci/arch/Meteoformer.py" +--8<-- +ppsci/arch/Meteoformer.py:414:415 +--8<-- +``` + +``` py linenums="35" title="ppsci/arch/Meteoformer.py" +--8<-- +ppsci/arch/Meteoformer.py:317:332 +--8<-- +``` + +## 3. 模型训练 + +### 3.1 数据集介绍 + +案例中使用了预处理的ERA5Meteo数据集,属于ERA5再分析数据的一个子集。ERA5Meteo包含了全球大气、陆地和海洋的多种变量,分辨率为31公里。该数据集从1979年开始到2018年,每小时提供一次天气状况的估计,非常适合用于短中期多气象要素预测等任务。在实际应用过程中,时间间隔选取为6小时。 + +数据集被保存为 T x C x H x W 的矩阵,记录了相应地点和时间的对应气象要素的值,其中 T 为时间序列长度,C代表通道维,案例中选取了3个不同气压层的温度、相对湿度、东向风速、北向风速等气象信息,H 和 W 代表按照经纬度划分后的矩阵的高度和宽度。根据年份,数据集按照 7:2:1 划分为训练集、验证集,和测试集。案例中预先计算了气象要素数据的均值与标准差,用于后续的正则化操作。 + +### 3.2 模型训练 + +#### 3.2.1 模型构建 + +该案例基于 Meteoformer 模型实现,用 PaddleScience 代码表示如下: + +``` py linenums="79" title="examples/Meteoformer/mian.py" +--8<-- +examples/Meteoformer/main.py:92:92 +--8<-- +``` + +#### 3.2.2 约束器构建 + +本案例基于数据驱动的方法求解问题,因此需要使用 PaddleScience 内置的 `SupervisedConstraint` 构建监督约束器。在定义约束器之前,需要首先指定约束器中用于数据加载的各个参数。 + +训练集数据加载的代码如下: + +``` py linenums="20" title="examples/Meteoformer/main.py" +--8<-- +examples/Meteoformer/main.py:23:38 +--8<-- +``` + +定义监督约束的代码如下: + +``` py linenums="40" title="examples/Meteoformer/main.py" +--8<-- +examples/Meteoformer/main.py:57:61 +--8<-- +``` + +#### 3.2.3 评估器构建 + +本案例训练过程中会按照一定的训练轮数间隔,使用验证集评估当前模型的训练情况,需要使用 `SupervisedValidator` 构建评估器。 + +验证集数据加载的代码如下: + +``` py linenums="44" title="examples/Meteoformer/main.py" +--8<-- +examples/Meteoformer/main.py:68:78 +--8<-- +``` + +定义监督评估器的代码如下: + +``` py linenums="65" title="examples/Meteoformer/main.py" +--8<-- +examples/Meteoformer/main.py:81:88 +--8<-- +``` + +#### 3.2.4 学习率与优化器构建 + +本案例中学习率大小设置为 `1e-3`,优化器使用 `Adam`,用 PaddleScience 代码表示如下: + +``` py linenums="83" title="examples/Meteoformer/main.py" +--8<-- +examples/Meteoformer/main.py:95:99 +--8<-- +``` + +#### 3.2.5 模型训练 + +完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练。 + +``` py linenums="88" title="examples/Meteoformer/main.py" +--8<-- +examples/Meteoformer/main.py:115:117 +--8<-- +``` + +## 4. 完整代码 + +``` py linenums="1" title="examples/Meteoformer/main.py" +--8<-- +examples/Meteoformer/main.py +--8<-- \ No newline at end of file diff --git a/examples/meteoformer/conf/train.yaml b/examples/meteoformer/conf/train.yaml new file mode 100644 index 0000000000..aa59573ea6 --- /dev/null +++ b/examples/meteoformer/conf/train.yaml @@ -0,0 +1,71 @@ +defaults: + - ppsci_default + - TRAIN: train_default + - TRAIN/ema: ema_default + - TRAIN/swa: swa_default + - EVAL: eval_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_meteoformer + job: + name: ${mode} # name of logfile + chdir: false # keep current working directory unchanged + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval +seed: 1024 +output_dir: ${hydra:run.dir} +log_freq: 20 + +# set training hyper-parameters +SQ_LEN: 12 +IMG_H: 192 +IMG_W: 256 +USE_SAMPLED_DATA: false + +# set train data path +TRAIN_FILE_PATH: /patch/to/ERA5/ +DATA_MEAN_PATH: examples/weather/datasets/era5/stat/mean.nc +DATA_STD_PATH: examples/weather/datasets/era5/stat/std.nc + +# set evaluate data path +VALID_FILE_PATH: /patch/to/ERA5/ + +# model settings +MODEL: + input_keys: ["input"] + output_keys: ["output"] + shape_in: + - 12 + - 12 + - ${IMG_H} + - ${IMG_W} + +# training settings +TRAIN: + epochs: 150 + save_freq: 20 + eval_during_train: true + eval_freq: 20 + lr_scheduler: + epochs: ${TRAIN.epochs} + learning_rate: 0.001 + by_epoch: true + batch_size: 16 + pretrained_model_path: null + checkpoint_path: null + +# evaluation settings +EVAL: + pretrained_model_path: null + compute_metric_by_batch: true + eval_with_no_grad: true + batch_size: 16 diff --git a/examples/meteoformer/datasets/era5/stat/mean.nc b/examples/meteoformer/datasets/era5/stat/mean.nc new file mode 100644 index 0000000000..792096f663 Binary files /dev/null and b/examples/meteoformer/datasets/era5/stat/mean.nc differ diff --git a/examples/meteoformer/datasets/era5/stat/std.nc b/examples/meteoformer/datasets/era5/stat/std.nc new file mode 100644 index 0000000000..9e288e39f7 Binary files /dev/null and b/examples/meteoformer/datasets/era5/stat/std.nc differ diff --git a/examples/meteoformer/main.py b/examples/meteoformer/main.py new file mode 100644 index 0000000000..c499f1ed41 --- /dev/null +++ b/examples/meteoformer/main.py @@ -0,0 +1,169 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hydra +import utils as utils +from omegaconf import DictConfig +import ppsci + +def train(cfg: DictConfig): + # set train dataloader config + if not cfg.USE_SAMPLED_DATA: + train_dataloader_cfg = { + "dataset": { + "name": "ERA5MeteoDataset", + "file_path": cfg.TRAIN_FILE_PATH, + "input_keys": cfg.MODEL.afno.input_keys, + "label_keys": cfg.MODEL.afno.output_keys, + "size": (cfg.IMG_H, cfg.IMG_W), + }, + "sampler": { + "name": "BatchSampler", + "drop_last": True, + "shuffle": True, + }, + "batch_size": cfg.TRAIN.batch_size, + "num_workers": 1, + } + else: + NUM_GPUS_PER_NODE = 8 + train_dataloader_cfg = { + "dataset": { + "name": "ERA5SampledDataset", + "file_path": cfg.TRAIN_FILE_PATH, + "input_keys": cfg.MODEL.afno.input_keys, + "label_keys": cfg.MODEL.afno.output_keys, + }, + "sampler": { + "name": "DistributedBatchSampler", + "drop_last": True, + "shuffle": True, + }, + "batch_size": cfg.TRAIN.batch_size, + "num_workers": 1, + } + # set constraint + sup_constraint = ppsci.constraint.SupervisedConstraint( + train_dataloader_cfg, + ppsci.loss.MSELoss(), + name="Sup", + ) + constraint = {sup_constraint.name: sup_constraint} + + # set iters_per_epoch by dataloader length + ITERS_PER_EPOCH = len(sup_constraint.data_loader) + + # set eval dataloader config + eval_dataloader_cfg = { + "dataset": { + "name": "ERA5MeteoDataset", + "file_path": cfg.VALID_FILE_PATH, + "input_keys": cfg.MODEL.afno.input_keys, + "label_keys": cfg.MODEL.afno.output_keys, + "training": False, + "size": (cfg.IMG_H, cfg.IMG_W), + }, + "batch_size": cfg.EVAL.batch_size, + } + + # set validator + sup_validator = ppsci.validate.SupervisedValidator( + eval_dataloader_cfg, + ppsci.loss.MSELoss(), + metric={ + "MAE": ppsci.metric.MAE(keep_batch=True), + }, + name="Sup_Validator", + ) + validator = {sup_validator.name: sup_validator} + + # set model + model = ppsci.arch.Meteoformer(**cfg.MODEL) + + # init optimizer and lr scheduler + lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler) + lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH}) + lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)() + + optimizer = ppsci.optimizer.Adam(lr_scheduler)(model) + + # initialize solver + solver = ppsci.solver.Solver( + model, + constraint, + cfg.output_dir, + optimizer, + epochs=cfg.TRAIN.epochs, + iters_per_epoch=ITERS_PER_EPOCH, + eval_during_train=cfg.TRAIN.compute_metric_by_batch, + validator=validator, + compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + # train model + solver.train() + # evaluate after finished training + solver.eval() + +def evaluate(cfg: DictConfig): + # set eval dataloader config + eval_dataloader_cfg = { + "dataset": { + "name": "ERA5MeteoDataset", + "file_path": cfg.VALID_FILE_PATH, + "input_keys": cfg.MODEL.afno.input_keys, + "label_keys": cfg.MODEL.afno.output_keys, + "training": False, + "size": (cfg.IMG_H, cfg.IMG_W), + }, + "batch_size": cfg.EVAL.batch_size, + } + + # set validator + sup_validator = ppsci.validate.SupervisedValidator( + eval_dataloader_cfg, + ppsci.loss.MSELoss(), + metric={ + "MAE": ppsci.metric.MAE(keep_batch=True), + }, + name="Sup_Validator", + ) + validator = {sup_validator.name: sup_validator} + + # set model + model = ppsci.arch.Meteoformer(**cfg.MODEL.afno) + + # initialize solver + solver = ppsci.solver.Solver( + model, + output_dir=cfg.output_dir, + log_freq=cfg.log_freq, + pretrained_model_path=cfg.EVAL.pretrained_model_path, + compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + # evaluate + solver.eval() + +@hydra.main(version_base=None, config_path="./conf", config_name="train.yaml") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + +if __name__ == "__main__": + main() diff --git a/examples/meteoformer/utils.py b/examples/meteoformer/utils.py new file mode 100644 index 0000000000..c766471e98 --- /dev/null +++ b/examples/meteoformer/utils.py @@ -0,0 +1,22 @@ +from datetime import datetime +from typing import Tuple + +import xarray as xr + + +def date_to_hours(date: str): + date_obj = datetime.strptime(date, "%Y-%m-%d %H:%M:%S") + day_of_year = date_obj.timetuple().tm_yday - 1 + hour_of_day = date_obj.timetuple().tm_hour + hours_since_jan_01_epoch = 24 * day_of_year + hour_of_day + return hours_since_jan_01_epoch + + +def get_mean_std(mean_path: str, std_path: str, vars_channel: Tuple[int, ...]): + data_mean = xr.open_mfdataset(mean_path)["mean"].values + data_std = xr.open_mfdataset(std_path)["std"].values + + data_mean.resize(data_mean.shape[0], 1, 1) + data_std.resize(data_std.shape[0], 1, 1) + + return data_mean, data_std diff --git a/ppsci/arch/meteoformer.py b/ppsci/arch/meteoformer.py new file mode 100644 index 0000000000..fa74f49883 --- /dev/null +++ b/ppsci/arch/meteoformer.py @@ -0,0 +1,424 @@ +from typing import Optional +from typing import Tuple +import numpy as np +from paddle import nn +from ppsci.arch import base + +def stride_generator(N, reverse=False): + strides = [1, 2] * 10 + if reverse: + return list(reversed(strides[:N])) + else: + return strides[:N] + +class ConvSC(nn.Layer): + def __init__(self, C_in: int, C_out: int, stride: int, transpose: bool = False): + super(ConvSC, self).__init__() + if stride == 1: + transpose = False + if not transpose: + self.conv = nn.Conv2D( + C_in, + C_out, + kernel_size=3, + stride=stride, + padding=1, + weight_attr=nn.initializer.KaimingNormal(), + ) + else: + self.conv = nn.Conv2DTranspose( + C_in, + C_out, + kernel_size=3, + stride=stride, + padding=1, + output_padding=stride // 2, + weight_attr=nn.initializer.KaimingNormal(), + ) + self.norm = nn.GroupNorm(2, C_out) + self.act = nn.LeakyReLU(0.2) + + def forward(self, x): + y = self.conv(x) + y = self.act(self.norm(y)) + return y + +class OverlapPatchEmbed(nn.Layer): + """Image to Patch Embedding""" + + def __init__( + self, + img_size: int = 224, + patch_size: int = 7, + stride: int = 4, + in_chans: int = 3, + embed_dim: int = 768, + ): + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + + self.img_size = img_size + self.patch_size = patch_size + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2D( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2).transpose(perm=[0, 2, 1]) + x = self.norm(x) + + return x, H, W + +class DWConv(nn.Layer): + def __init__(self, dim: int = 768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(perm=[0, 2, 1]).reshape([B, C, H, W]) + x = self.dwconv(x) + x = x.flatten(2).transpose(perm=[0, 2, 1]) + + return x + +class Mlp(nn.Layer): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: nn.Layer = nn.GELU, + drop: float = 0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class Attention(nn.Layer): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: Optional[int] = None, + qk_scale: Optional[int] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + sr_ratio: float = 1.0, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias_attr=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(axis=-1) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2D(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + + def forward(self, x, H, W): + B, N, C = x.shape + q = ( + self.q(x) + .reshape([B, N, self.num_heads, C // self.num_heads]) + .transpose(perm=[0, 2, 1, 3]) + ) + + if self.sr_ratio > 1: + x_ = x.transpose(perm=[0, 2, 1]).reshape([B, C, H, W]) + x_ = self.sr(x_).reshape([B, C, -1]).transpose(perm=[0, 2, 1]) + x_ = self.norm(x_) + kv = ( + self.kv(x_) + .reshape([B, -1, 2, self.num_heads, C // self.num_heads]) + .transpose(perm=[2, 0, 3, 1, 4]) + ) + else: + kv = ( + self.kv(x) + .reshape([B, -1, 2, self.num_heads, C // self.num_heads]) + .transpose(perm=[2, 0, 3, 1, 4]) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(perm=[0, 1, 3, 2])) * self.scale + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(perm=[0, 2, 1, 3]).reshape([B, N, C]) + x = self.norm(x) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class Block(nn.Layer): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: Optional[int] = None, + qk_scale: Optional[int] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: nn.Layer = nn.GELU, + norm_layer: nn.Layer = nn.LayerNorm, + sr_ratio: float = 1.0, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) + self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + + return x + +class Encoder(nn.Layer): + def __init__(self, C_in: int, C_hid: int, N_S: int): + super(Encoder, self).__init__() + strides = stride_generator(N_S) + + self.enc0 = ConvSC(C_in, C_hid, stride=strides[0]) + self.enc1 = OverlapPatchEmbed( + img_size=256, patch_size=7, stride=4, in_chans=C_hid, embed_dim=C_hid + ) + self.enc2 = Block( + dim=C_hid, + num_heads=1, + mlp_ratio=4, + qkv_bias=None, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + sr_ratio=8, + ) + self.norm1 = nn.LayerNorm(C_hid) + + def forward(self, x): # B*4, 3, 128, 128 + B = x.shape[0] + latent = [] + x = self.enc0(x) + latent.append(x) + x, H, W = self.enc1(x) + x = self.enc2(x, H, W) + x = self.norm1(x) + x = x.reshape([B, H, W, -1]).transpose(perm=[0, 3, 1, 2]).contiguous() + latent.append(x) + + return latent + +class MidXnet(nn.Layer): + def __init__( + self, + channel_in: int, + channel_hid: int, + N_T: int, + incep_ker: Tuple[int, ...] = [3, 5, 7, 11], + groups: int = 8, + ): + super(MidXnet, self).__init__() + + self.N_T = N_T + dpr = [x.item() for x in np.linspace(0, 0.1, N_T)] + enc_layers = [] + for i in range(N_T): + enc_layers.append( + Block( + dim=channel_in, + num_heads=4, + mlp_ratio=4, + qkv_bias=None, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=dpr[i], + norm_layer=nn.LayerNorm, + sr_ratio=8, + ) + ) + + self.enc = nn.Sequential(*enc_layers) + + def forward(self, x): + B, T, C, H, W = x.shape + # B TC H W + + x = x.reshape([B, T * C, H, W]) + # B HW TC + x = x.flatten(2).transpose(perm=[0, 2, 1]) + + # encoder + z = x + for i in range(self.N_T): + z = self.enc[i](z, H, W) + + return z + +# MultiDecoder +class Decoder(nn.Layer): + def __init__(self, C_hid: int, C_out: int, N_S: int): + super(Decoder, self).__init__() + strides = stride_generator(N_S, reverse=True) + # strides = [2, 1, 2, 1] + self.dec = nn.Sequential( + *[ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]], + ConvSC(C_hid, C_hid, stride=strides[-1], transpose=True), + ) + self.readout = nn.Conv2D(C_hid, C_out, 1) + + def forward(self, hid, enc1=None): + for i in range(0, len(self.dec)): + hid = self.dec[i](hid) + Y = self.readout(hid) + return Y + +class Meteoformer(base.Arch): + """ + Meteoformer is a class that represents a Spatial-Temporal Transformer model designed for short-to-medium-term weather prediction with multiple meteorological variables. + + Args: + input_keys (Tuple[str, ...]): A tuple of input keys. + output_keys (Tuple[str, ...]): A tuple of output keys. + shape_in (Tuple[int, ...]): The shape of the input data (T, C, H, W), where + T is the number of time steps, C is the number of channels, + H and W are the spatial dimensions. + hid_S (int): The number of hidden channels in the spatial encoder. + hid_T (int): The number of hidden units in the temporal encoder. + N_S (int): The number of spatial transformer layers. + N_T (int): The number of temporal transformer layers. + incep_ker (Tuple[int, ...]): The kernel sizes used in the inception block. + groups (int): The number of groups for grouped convolutions. + num_classes (int): The number of predicted meteorological variables. + + Examples: + >>> import paddle + >>> import ppsci + >>> model = ppsci.arch.Meteoformer( + ... input_keys=("input",), + ... output_keys=("output",), + ... shape_in=(12, 12, 192, 256), + ... hid_S=64, + ... hid_T=256, + ... N_S=4, + ... N_T=4, + ... incep_ker=(3, 5, 7, 11), + ... groups=8, + ... num_classes=4, + ... ) + >>> input_dict = {"input": paddle.rand([16, 12, 4, 192, 256])} + >>> output_dict = model(input_dict) + >>> print(output_dict["output"].shape) + [16, 12, 4, 192, 256] + """ + + def __init__( + self, + input_keys: Tuple[str, ...], + output_keys: Tuple[str, ...], + shape_in: Tuple[int, ...], + hid_S: int = 64, + hid_T: int = 256, + N_S: int = 4, + N_T: int = 4, + incep_ker: Tuple[int, ...] = [3, 5, 7, 11], + groups: int = 8, + num_classes: int = 4, + ): + super(Meteoformer, self).__init__() + self.input_keys = input_keys + self.output_keys = output_keys + + T, C, H, W = shape_in + self.enc = Encoder(C, hid_S, N_S) + self.hid1 = MidXnet(T * hid_S, hid_T // 2, N_T, incep_ker, groups) + self.dec = Decoder(T * hid_S, T * num_classes, N_S) + + def forward(self, x): + if self._input_transform is not None: + x = self._input_transform(x) + + x = self.concat_to_tensor(x, self.input_keys) + + B, T, C, H, W = x.shape + x = x.reshape([B * T, C, H, W]) + + # encoded + embed = self.enc(x) + _, C_4, H_4, W_4 = embed[-1].shape + + # translator + z = embed[-1].reshape([B, T, C_4, H_4, W_4]) + hid = self.hid1(z) + hid = hid.transpose(perm=[0, 2, 1]).reshape([B, -1, H_4, W_4]) + + # decoded + y = self.dec(hid, embed[0]) + y = y.reshape([B, T, 4, H, W]) + + y = self.split_to_dict(y, self.output_keys) + + if self._output_transform is not None: + y = self._output_transform(x, y) + + return y + + diff --git a/ppsci/data/dataset/era5meteo_dataset.py b/ppsci/data/dataset/era5meteo_dataset.py new file mode 100644 index 0000000000..4e2bb90eab --- /dev/null +++ b/ppsci/data/dataset/era5meteo_dataset.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import datetime +import numbers +import os +import random +from typing import Dict +from typing import Optional +from typing import Tuple + +import h5py +import numpy as np +import paddle +from paddle import io +from paddle import vision + + +class ERA5MeteoDataset(io.Dataset): + """Class for ERA5 dataset. + + Args: + file_path (str): Dataset path. + input_keys (Tuple[str, ...]): Input keys, such as ("input",). + label_keys (Tuple[str, ...]): Output keys, such as ("output",). + precip_file_path (Optional[str]): Precipitation data set path. Defaults to None. + weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None. + vars_channel (Optional[Tuple[int, ...]]): The variable channel index in ERA5 dataset. Defaults to None. + num_label_timestamps (int, optional): Number of timestamp of label. Defaults to 1. + transforms (Optional[vision.Compose]): Compose object contains sample wise + transform(s). Defaults to None. + training (bool, optional): Whether in train mode. Defaults to True. + stride (int, optional): Stride of sampling data. Defaults to 1. + + Examples: + >>> import ppsci + >>> dataset = ppsci.data.dataset.ERA5MeteoDataset( + ... "file_path": "/path/to/ERA5MeteoDataset", + ... "input_keys": ("input",), + ... "label_keys": ("output",), + ... "size": ("H", "W"), + ... ) # doctest: +SKIP + """ + + # Whether support batch indexing for speeding up fetching process. + batch_index: bool = False + + def __init__( + self, + file_path: str, + input_keys: Tuple[str, ...], + label_keys: Tuple[str, ...], + size: Tuple[int, ...], + weight_dict: Optional[Dict[str, float]] = None, + transforms: Optional[vision.Compose] = None, + training: bool = True, + stride: int = 1, + sq_length: int = 6, + time_step: int = 6, + ): + super().__init__() + self.file_path = file_path + self.input_keys = input_keys + self.label_keys = label_keys + self.size = size + self.training = training + self.sq_length = sq_length + + self.time_step = time_step + + self.transforms = transforms + + self.weight_dict = {} if weight_dict is None else weight_dict + if weight_dict is not None: + self.weight_dict = {key: 1.0 for key in self.label_keys} + self.weight_dict.update(weight_dict) + + # load precipitation data + if training: + self.precipitation = h5py.File( + os.path.join(self.file_path, "train", "rain_2016_01.h5") + ) + else: + self.precipitation = h5py.File( + os.path.join(self.file_path, "test", "rain_2020_02.h5") + ) + + t_list = self.precipitation["time"][:] + start_time = datetime.datetime(1900, 1, 1, 0, 0, 0) + self.time_table = [] + for i in range(len(t_list)): + temp = start_time + datetime.timedelta(hours=int(t_list[i])) + self.time_table.append(temp) + + def __len__(self): + return len(self.time_table) - 24 + + def __getitem__(self, global_idx): + + x, y_t, y_r, y_u, y_v = [], [], [], [], [] + + for m in range(self.sq_length): + x.append(self.load_data(global_idx + m * self.time_step)) + for n in range(self.sq_length): + future_data = self.load_data(global_idx + (self.sq_length + n) * self.time_step) + y_t.append(future_data[1]) # Temperature + y_r.append(future_data[0]) # Humidity + y_u.append(future_data[2]) # U-Wind + y_v.append(future_data[3]) # V-Wind + + x, y_t, y_r, y_u, y_v = self._random_crop(x, y_t, y_r, y_u, y_v) + + input_item = {self.input_keys[0]: np.stack(x, axis=0)} + label_item = {self.label_keys[0]: np.stack([y_t, y_r, y_u, y_v], axis=0)} + + weight_shape = [1] * len(next(iter(label_item.values())).shape) + weight_item = { + key: np.full(weight_shape, value, paddle.get_default_dtype()) + for key, value in self.weight_dict.items() + } + + if self.transforms is not None: + input_item, label_item, weight_item = self.transforms( + input_item, label_item, weight_item + ) + + return input_item, label_item, weight_item + + def load_data(self, indices): + year = str(self.time_table[indices].timetuple().tm_year) + mon = str(self.time_table[indices].timetuple().tm_mon) + if len(mon) == 1: + mon = "0" + mon + day = str(self.time_table[indices].timetuple().tm_mday) + if len(day) == 1: + day = "0" + day + hour = str(self.time_table[indices].timetuple().tm_hour) + if len(hour) == 1: + hour = "0" + hour + r_data = np.load(os.path.join(self.file_path, year, f"r_{year}{mon}{day}{hour}.npy")) + t_data = np.load(os.path.join(self.file_path, year, f"t_{year}{mon}{day}{hour}.npy")) + u_data = np.load(os.path.join(self.file_path, year, f"u_{year}{mon}{day}{hour}.npy")) + v_data = np.load(os.path.join(self.file_path, year, f"v_{year}{mon}{day}{hour}.npy")) + + data = np.concatenate([r_data, t_data, u_data, v_data]) + + return data + + def _random_crop(self, x, y_t, y_r, y_u, y_v): + if isinstance(self.size, numbers.Number): + self.size = (int(self.size), int(self.size)) + th, tw = self.size + h, w = y_t[0].shape[-2], y_t[0].shape[-1] + x1, y1 = random.randint(0, w - tw), random.randint(0, h - th) + + # Apply cropping + x = [self._crop(xi, y1, x1, y1 + th, x1 + tw) for xi in x] + y_t = [self._crop(y, y1, x1, y1 + th, x1 + tw) for y in y_t] + y_r = [self._crop(y, y1, x1, y1 + th, x1 + tw) for y in y_r] + y_u = [self._crop(y, y1, x1, y1 + th, x1 + tw) for y in y_u] + y_v = [self._crop(y, y1, x1, y1 + th, x1 + tw) for y in y_v] + + return x, y_t, y_r, y_u, y_v + + def _crop(self, im, x_start, y_start, x_end, y_end): + if len(im.shape) == 3: + return im[:, x_start:x_end, y_start:y_end] + else: + return im[x_start:x_end, y_start:y_end]