diff --git a/pdebench/models/unet/train.py b/pdebench/models/unet/train.py
index 858f4d0..a34a6f0 100644
--- a/pdebench/models/unet/train.py
+++ b/pdebench/models/unet/train.py
@@ -88,16 +88,10 @@ def run_training(
     train_data = UNetDatasetMult(
         flnm,
-        reduced_resolution=reduced_resolution,
-        reduced_resolution_t=reduced_resolution_t,
-        reduced_batch=reduced_batch,
         saved_folder=base_path,
     )
 
     val_data = UNetDatasetMult(
         flnm,
-        reduced_resolution=reduced_resolution,
-        reduced_resolution_t=reduced_resolution_t,
-        reduced_batch=reduced_batch,
         if_test=True,
         saved_folder=base_path,
     )
@@ -312,19 +306,19 @@ def run_training(
                 loss.backward()
                 optimizer.step()
 
-                if training_type in ["single"]:
-                    x = xx[..., 0, :]
-                    y = yy[..., t_train - 1 : t_train, :]
-                    pred = model(x.permute([0, 2, 1])).permute([0, 2, 1])
-                    _batch = yy.size(0)
-                    loss += loss_fn(pred.reshape(_batch, -1), y.reshape(_batch, -1))
+            if training_type in ["single"]:
+                x = xx[..., 0, :]
+                y = yy[..., t_train - 1 : t_train, :]
+                pred = model(x.permute([0, 2, 1])).permute([0, 2, 1])
+                _batch = yy.size(0)
+                loss += loss_fn(pred.reshape(_batch, -1), y.reshape(_batch, -1))
 
-                    train_l2_step += loss.item()
-                    train_l2_full += loss.item()
+                train_l2_step += loss.item()
+                train_l2_full += loss.item()
 
-                    optimizer.zero_grad()
-                    loss.backward()
-                    optimizer.step()
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
 
         if ep % model_update == 0:
             val_l2_step = 0
@@ -370,15 +364,15 @@ def run_training(
                             _pred.reshape(_batch, -1), _yy.reshape(_batch, -1)
                         ).item()
 
-                        if training_type in ["single"]:
-                            x = xx[..., 0, :]
-                            y = yy[..., t_train - 1 : t_train, :]
-                            pred = model(x.permute([0, 2, 1])).permute([0, 2, 1])
-                            _batch = yy.size(0)
-                            loss += loss_fn(pred.reshape(_batch, -1), y.reshape(_batch, -1))
+                    if training_type in ["single"]:
+                        x = xx[..., 0, :]
+                        y = yy[..., t_train - 1 : t_train, :]
+                        pred = model(x.permute([0, 2, 1])).permute([0, 2, 1])
+                        _batch = yy.size(0)
+                        loss += loss_fn(pred.reshape(_batch, -1), y.reshape(_batch, -1))
 
-                        val_l2_step += loss.item()
-                        val_l2_full += loss.item()
+                    val_l2_step += loss.item()
+                    val_l2_full += loss.item()
 
                 if val_l2_full < loss_val_min:
                     loss_val_min = val_l2_full