diff --git a/pufferlib/extensions/pufferlib.cpp b/pufferlib/extensions/pufferlib.cpp index fadd2858e..675f4d75d 100644 --- a/pufferlib/extensions/pufferlib.cpp +++ b/pufferlib/extensions/pufferlib.cpp @@ -366,6 +366,7 @@ void train_impl(PuffeRL& pufferl) { Muon* muon = pufferl.muon; int total_epochs = hypers.total_timesteps / batch_size; + if (total_epochs < 1) total_epochs = 1; if (anneal_lr) { float lr_min = hypers.min_lr_ratio * hypers.lr;