From 6d5488cccd413fb1c9c93e2f1ee7634f656f4783 Mon Sep 17 00:00:00 2001 From: till-m Date: Mon, 13 Oct 2025 17:10:55 +0200 Subject: [PATCH 1/2] fix indentation --- pdebench/models/unet/train.py | 38 +++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/pdebench/models/unet/train.py b/pdebench/models/unet/train.py index 858f4d0..822a698 100644 --- a/pdebench/models/unet/train.py +++ b/pdebench/models/unet/train.py @@ -312,19 +312,19 @@ def run_training( loss.backward() optimizer.step() - if training_type in ["single"]: - x = xx[..., 0, :] - y = yy[..., t_train - 1 : t_train, :] - pred = model(x.permute([0, 2, 1])).permute([0, 2, 1]) - _batch = yy.size(0) - loss += loss_fn(pred.reshape(_batch, -1), y.reshape(_batch, -1)) + if training_type in ["single"]: + x = xx[..., 0, :] + y = yy[..., t_train - 1 : t_train, :] + pred = model(x.permute([0, 2, 1])).permute([0, 2, 1]) + _batch = yy.size(0) + loss += loss_fn(pred.reshape(_batch, -1), y.reshape(_batch, -1)) - train_l2_step += loss.item() - train_l2_full += loss.item() + train_l2_step += loss.item() + train_l2_full += loss.item() - optimizer.zero_grad() - loss.backward() - optimizer.step() + optimizer.zero_grad() + loss.backward() + optimizer.step() if ep % model_update == 0: val_l2_step = 0 @@ -370,15 +370,15 @@ def run_training( _pred.reshape(_batch, -1), _yy.reshape(_batch, -1) ).item() - if training_type in ["single"]: - x = xx[..., 0, :] - y = yy[..., t_train - 1 : t_train, :] - pred = model(x.permute([0, 2, 1])).permute([0, 2, 1]) - _batch = yy.size(0) - loss += loss_fn(pred.reshape(_batch, -1), y.reshape(_batch, -1)) + if training_type in ["single"]: + x = xx[..., 0, :] + y = yy[..., t_train - 1 : t_train, :] + pred = model(x.permute([0, 2, 1])).permute([0, 2, 1]) + _batch = yy.size(0) + loss += loss_fn(pred.reshape(_batch, -1), y.reshape(_batch, -1)) - val_l2_step += loss.item() - val_l2_full += loss.item() + val_l2_step += loss.item() + val_l2_full += loss.item() if val_l2_full < loss_val_min: loss_val_min = val_l2_full From 08a769de77e938ade65882dfb40ea70e6b6bee5c Mon Sep 17 00:00:00 2001 From: till-m Date: Mon, 13 Oct 2025 17:11:04 +0200 Subject: [PATCH 2/2] fix kwargs --- pdebench/models/unet/train.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pdebench/models/unet/train.py b/pdebench/models/unet/train.py index 822a698..a34a6f0 100644 --- a/pdebench/models/unet/train.py +++ b/pdebench/models/unet/train.py @@ -88,16 +88,10 @@ def run_training( train_data = UNetDatasetMult( flnm, - reduced_resolution=reduced_resolution, - reduced_resolution_t=reduced_resolution_t, - reduced_batch=reduced_batch, saved_folder=base_path, ) val_data = UNetDatasetMult( flnm, - reduced_resolution=reduced_resolution, - reduced_resolution_t=reduced_resolution_t, - reduced_batch=reduced_batch, if_test=True, saved_folder=base_path, )