[BUG] Evo2 finetune with LoRA does not work #1136

@dorotat-nv


BioNeMo Framework Version

0d162d5

Bug Description

When running Evo2 finetuning with LoRA, the metrics (e.g., val_loss, other validation metrics, and reduced_train_loss) do not decrease as expected and instead remain almost fixed across training steps. The correct behavior is observed for Evo2 finetuning without LoRA, i.e., the metrics decrease. The Evo2 finetuning recipes, both with and without LoRA, are defined in link.

(Attached image: training and validation loss curves remaining essentially flat across steps.)

The external contribution that was merged in https://github.com/NVIDIA/bionemo-framework/pull/980:

  1. Did not include any unit tests to verify LoRA functionality for Evo2, even though it was reported to work on a single GPU (single node).
  2. Was not validated against finetuning without LoRA, not even by the author. There is no supporting evidence that it runs correctly or produces meaningful results, making it highly experimental.

Tasks to do

  1. Create a unit test demonstrating that LoRA finetuning is failing (see the sketch after this list).

  2. Fix the LoRA finetuning functionality.

  3. Ensure the unit test passes once the fix is in place.

  4. If the issue cannot be resolved immediately, add a warning to the training script indicating that LoRA training is currently broken.
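
As a starting point for task 1, here is a minimal, framework-agnostic sketch of the property such a test should assert: with only the LoRA adapters trainable, a handful of optimizer steps should still reduce the loss and move the adapter weights. It uses plain PyTorch rather than the train_evo2 / NeMo stack, and `LinearWithLoRA`, the toy data, and the tolerance are illustrative assumptions, not the repository's API.

```python
# Framework-agnostic sketch of the check the unit test should make: after a few
# optimizer steps with only LoRA adapters trainable, the loss should drop and the
# adapter weights should receive non-zero updates. Plain PyTorch, not BioNeMo/NeMo;
# LinearWithLoRA and the thresholds are illustrative stand-ins.
import torch
import torch.nn as nn


class LinearWithLoRA(nn.Module):
    """Frozen base linear layer plus a trainable low-rank (LoRA) update."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.base.weight.requires_grad_(False)  # base weights stay frozen
        self.base.bias.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, rank))
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * (x @ self.lora_a.T @ self.lora_b.T)


def test_lora_adapters_learn():
    torch.manual_seed(0)
    model = LinearWithLoRA(16, 1)
    x = torch.randn(64, 16)
    y = x.sum(dim=1, keepdim=True)  # simple learnable target
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-2)
    loss_fn = nn.MSELoss()

    initial_loss = loss_fn(model(x), y).item()
    for _ in range(200):
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()

    # Adapter weights must have moved and the loss must have dropped; a flat loss,
    # as seen in this bug report, would fail these assertions.
    assert loss.item() < 0.9 * initial_loss
    assert model.lora_b.abs().sum().item() > 0.0
```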

Steps to Reproduce

  1. Build a container for 0d162d5ffe97eee471d259e13d82f212f5b4dc5d

  2. Set up a job (1 node) on eos and mount /data/:/lustre/fsw/healthcareeng_bionemo/jet/data

  3. Run the Evo2 LoRA finetuning command:

```bash
train_evo2 -d /data/evo2/training_data_config.yaml \
  --dataset-dir=/data/evo2/preprocessed_data \
  --ckpt-dir=/data/evo2/checkpoints/nemo2_evo2_1b_8k \
  --lora-finetune \
  --model-size=1b \
  --max-steps=1000 \
  --experiment-name=evo2-lora-finetune_2bs_1node_1gpu_1000s \
  --lr=1.5e-05 \
  --min-lr=1.49e-05 \
  --warmup-steps=10 \
  --result-dir=/jet/logs/recipe/model-evo2_variant-lora-finetune_bionemo_partial-conv_nodes-1_gpus-1_batch_size-2_config_name-1b_max_steps-1000_stop_steps-100_target-dgxh100-eos_task-lora-finetune-from-ckpt/tensorboard_logs \
  --micro-batch-size=2 \
  --grad-acc-batches=4 \
  --limit-val-batches=20 \
  --seq-length=8192 \
  --clip-grad=250 \
  --wd=0.001 \
  --attention-dropout=0.01 \
  --hidden-dropout=0.01 \
  --num-layers 4 \
  --hybrid-override-pattern 'SDH*' \
  --devices=1 \
  --num-nodes=1 \
  --val-check-interval=5 \
  --wandb-project=jet--partial-conv--main \
  --wandb-group=evo2_lora_finetune_1b_lora_finetune_from_ckpt_dgxh100_eos \
  --create-tensorboard-logger \
  --activation-checkpoint-recompute-num-layers=2 \
  --disable-checkpointing \
  --early-stop-on-step=100 \
  --garbage-collect-at-inference
```
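
Not part of the repro recipe, but when debugging a flat loss under --lora-finetune, one of the first things worth checking is whether any parameters are actually left trainable after the adapter transform is applied and whether the optimizer sees them. A generic PyTorch sketch, assuming `model` and `optimizer` stand for whatever objects the training script builds (placeholder names, not BioNeMo APIs):

```python
# Hypothetical debugging helper: print which parameters are trainable and whether the
# optimizer actually holds them. Call it right after the model and optimizer are built,
# e.g. summarize_trainable(module, optimizer).
def summarize_trainable(model, optimizer=None):
    trainable = [(name, p.numel()) for name, p in model.named_parameters() if p.requires_grad]
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"trainable: {sum(n for _, n in trainable):,} params | frozen: {frozen:,} params")
    for name, numel in trainable:
        print(f"  {name}: {numel:,}")
    if optimizer is not None:
        opt_param_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
        missing = [name for name, p in model.named_parameters()
                   if p.requires_grad and id(p) not in opt_param_ids]
        if missing:
            print("WARNING: trainable params not present in the optimizer:", missing)
```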

Error Messages and Logs

Training logs 

......
0: Training epoch 0, iteration 0/99 | lr: 0 | global_batch_size: 8 | global_step: 0 | reduced_train_loss: 6.278 | train_step_timing in s: 3.453
0: Training epoch 0, iteration 1/99 | lr: 1.5e-06 | global_batch_size: 8 | global_step: 1 | reduced_train_loss: 6.3 | train_step_timing in s: 0.1276 | consumed_samples: 16
0: Training epoch 0, iteration 2/99 | lr: 3e-06 | global_batch_size: 8 | global_step: 2 | reduced_train_loss: 6.283 | train_step_timing in s: 0.1269 | consumed_samples: 24
0: Training epoch 0, iteration 3/99 | lr: 4.5e-06 | global_batch_size: 8 | global_step: 3 | reduced_train_loss: 6.285 | train_step_timing in s: 0.1271 | consumed_samples: 32
0: Training epoch 0, iteration 4/99 | lr: 6e-06 | global_batch_size: 8 | global_step: 4 | reduced_train_loss: 6.296 | train_step_timing in s: 0.1269 | consumed_samples: 40
0: [NeMo W 2025-09-06 12:29:56 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('global_batch_size', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
0:     
0: [NeMo W 2025-09-06 12:29:56 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
0:     
0: Validation: iteration 1/20
0: Validation: iteration 2/20
0: Validation: iteration 3/20
0: Validation: iteration 4/20
0: Validation: iteration 5/20
0: Validation: iteration 6/20
0: Validation: iteration 7/20
0: Validation: iteration 8/20
0: Validation: iteration 9/20
0: Validation: iteration 10/20
0: Validation: iteration 11/20
0: Validation: iteration 12/20
0: Validation: iteration 13/20
0: Validation: iteration 14/20
0: Validation: iteration 15/20
0: Validation: iteration 16/20
0: Validation: iteration 17/20
0: Validation: iteration 18/20
0: Validation: iteration 19/20
0: Validation: iteration 20/20
0: Training epoch 0, iteration 5/99 | lr: 7.5e-06 | global_batch_size: 8 | global_step: 5 | reduced_train_loss: 6.299 | train_step_timing in s: 0.1278 | consumed_samples: 48 | val_loss: 6.285
0: Training epoch 0, iteration 6/99 | lr: 9e-06 | global_batch_size: 8 | global_step: 6 | reduced_train_loss: 6.282 | train_step_timing in s: 0.127 | consumed_samples: 56 | val_loss: 6.285
0: Training epoch 0, iteration 7/99 | lr: 1.05e-05 | global_batch_size: 8 | global_step: 7 | reduced_train_loss: 6.293 | train_step_timing in s: 0.127 | consumed_samples: 64 | val_loss: 6.285
0: Training epoch 0, iteration 8/99 | lr: 1.2e-05 | global_batch_size: 8 | global_step: 8 | reduced_train_loss: 6.288 | train_step_timing in s: 0.1271 | consumed_samples: 72 | val_loss: 6.285
0: Training epoch 0, iteration 9/99 | lr: 1.35e-05 | global_batch_size: 8 | global_step: 9 | reduced_train_loss: 6.281 | train_step_timing in s: 0.1272 | consumed_samples: 80 | val_loss: 6.285
0: Validation: iteration 1/20
0: Validation: iteration 2/20
0: Validation: iteration 3/20
0: Validation: iteration 4/20
0: Validation: iteration 5/20
0: Validation: iteration 6/20
0: Validation: iteration 7/20
0: Validation: iteration 8/20
0: Validation: iteration 9/20
0: Validation: iteration 10/20
0: Validation: iteration 11/20
0: Validation: iteration 12/20
0: Validation: iteration 13/20
0: Validation: iteration 14/20
0: Validation: iteration 15/20
0: Validation: iteration 16/20
0: Validation: iteration 17/20
0: Validation: iteration 18/20
0: Validation: iteration 19/20
0: Validation: iteration 20/20
0: Training epoch 0, iteration 10/99 | lr: 1.5e-05 | global_batch_size: 8 | global_step: 10 | reduced_train_loss: 6.296 | train_step_timing in s: 0.1277 | consumed_samples: 88 | val_loss: 6.285
0: Training epoch 0, iteration 11/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 11 | reduced_train_loss: 6.283 | train_step_timing in s: 0.127 | consumed_samples: 96 | val_loss: 6.285
0: Training epoch 0, iteration 12/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 12 | reduced_train_loss: 6.265 | train_step_timing in s: 0.127 | consumed_samples: 104 | val_loss: 6.285
0: Training epoch 0, iteration 13/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 13 | reduced_train_loss: 6.275 | train_step_timing in s: 0.1268 | consumed_samples: 112 | val_loss: 6.285
0: Training epoch 0, iteration 14/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 14 | reduced_train_loss: 6.295 | train_step_timing in s: 0.1269 | consumed_samples: 120 | val_loss: 6.285
0: Validation: iteration 1/20
0: Validation: iteration 2/20
0: Validation: iteration 3/20
0: Validation: iteration 4/20
0: Validation: iteration 5/20
0: Validation: iteration 6/20
0: Validation: iteration 7/20
0: Validation: iteration 8/20
0: Validation: iteration 9/20
0: Validation: iteration 10/20
0: Validation: iteration 11/20
0: Validation: iteration 12/20
0: Validation: iteration 13/20
0: Validation: iteration 14/20
0: Validation: iteration 15/20
0: Validation: iteration 16/20
0: Validation: iteration 17/20
0: Validation: iteration 18/20
0: Validation: iteration 19/20
0: Validation: iteration 20/20
0: Training epoch 0, iteration 15/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 15 | reduced_train_loss: 6.287 | train_step_timing in s: 0.1276 | consumed_samples: 128 | val_loss: 6.285
0: Training epoch 0, iteration 16/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 16 | reduced_train_loss: 6.284 | train_step_timing in s: 0.1271 | consumed_samples: 136 | val_loss: 6.285
0: Training epoch 0, iteration 17/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 17 | reduced_train_loss: 6.31 | train_step_timing in s: 0.1267 | consumed_samples: 144 | val_loss: 6.285
0: Training epoch 0, iteration 18/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 18 | reduced_train_loss: 6.28 | train_step_timing in s: 0.1268 | consumed_samples: 152 | val_loss: 6.285
0: Training epoch 0, iteration 19/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 19 | reduced_train_loss: 6.29 | train_step_timing in s: 0.1268 | consumed_samples: 160 | val_loss: 6.285
0: Validation: iteration 1/20
0: Validation: iteration 2/20
0: Validation: iteration 3/20
0: Validation: iteration 4/20
0: Validation: iteration 5/20
0: Validation: iteration 6/20
0: Validation: iteration 7/20
0: Validation: iteration 8/20
0: Validation: iteration 9/20
0: Validation: iteration 10/20
0: Validation: iteration 11/20
0: Validation: iteration 12/20
0: Validation: iteration 13/20
0: Validation: iteration 14/20
0: Validation: iteration 15/20
0: Validation: iteration 16/20
0: Validation: iteration 17/20
0: Validation: iteration 18/20
0: Validation: iteration 19/20
0: Validation: iteration 20/20
0: Training epoch 0, iteration 20/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 20 | reduced_train_loss: 6.274 | train_step_timing in s: 0.1278 | consumed_samples: 168 | val_loss: 6.285
0: Training epoch 0, iteration 21/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 21 | reduced_train_loss: 6.298 | train_step_timing in s: 0.1271 | consumed_samples: 176 | val_loss: 6.285
0: Training epoch 0, iteration 22/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 22 | reduced_train_loss: 6.284 | train_step_timing in s: 0.1269 | consumed_samples: 184 | val_loss: 6.285
0: Training epoch 0, iteration 23/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 23 | reduced_train_loss: 6.281 | train_step_timing in s: 0.1269 | consumed_samples: 192 | val_loss: 6.285
0: Training epoch 0, iteration 24/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 24 | reduced_train_loss: 6.292 | train_step_timing in s: 0.127 | consumed_samples: 200 | val_loss: 6.285
0: Validation: iteration 1/20
0: Validation: iteration 2/20
0: Validation: iteration 3/20
0: Validation: iteration 4/20
0: Validation: iteration 5/20
0: Validation: iteration 6/20
0: Validation: iteration 7/20
0: Validation: iteration 8/20
0: Validation: iteration 9/20
0: Validation: iteration 10/20
0: Validation: iteration 11/20
0: Validation: iteration 12/20
0: Validation: iteration 13/20
0: Validation: iteration 14/20
0: Validation: iteration 15/20
0: Validation: iteration 16/20
0: Validation: iteration 17/20
0: Validation: iteration 18/20
0: Validation: iteration 19/20
0: Validation: iteration 20/20
0: Training epoch 0, iteration 25/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 25 | reduced_train_loss: 6.293 | train_step_timing in s: 0.1274 | consumed_samples: 208 | val_loss: 6.285
0: Training epoch 0, iteration 26/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 26 | reduced_train_loss: 6.274 | train_step_timing in s: 0.1269 | consumed_samples: 216 | val_loss: 6.285
0: Training epoch 0, iteration 27/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 27 | reduced_train_loss: 6.295 | train_step_timing in s: 0.1269 | consumed_samples: 224 | val_loss: 6.285
0: Training epoch 0, iteration 28/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 28 | reduced_train_loss: 6.29 | train_step_timing in s: 0.1268 | consumed_samples: 232 | val_loss: 6.285
0: Training epoch 0, iteration 29/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 29 | reduced_train_loss: 6.26 | train_step_timing in s: 0.1266 | consumed_samples: 240 | val_loss: 6.285
0: Validation: iteration 1/20
0: Validation: iteration 2/20
0: Validation: iteration 3/20
0: Validation: iteration 4/20
0: Validation: iteration 5/20
0: Validation: iteration 6/20
0: Validation: iteration 7/20
0: Validation: iteration 8/20
0: Validation: iteration 9/20
0: Validation: iteration 10/20
0: Validation: iteration 11/20
0: Validation: iteration 12/20
0: Validation: iteration 13/20
0: Validation: iteration 14/20
0: Validation: iteration 15/20
0: Validation: iteration 16/20
0: Validation: iteration 17/20
0: Validation: iteration 18/20
0: Validation: iteration 19/20
0: Validation: iteration 20/20
0: Training epoch 0, iteration 30/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 30 | reduced_train_loss: 6.277 | train_step_timing in s: 0.1276 | consumed_samples: 248 | val_loss: 6.285
0: Training epoch 0, iteration 31/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 31 | reduced_train_loss: 6.27 | train_step_timing in s: 0.1271 | consumed_samples: 256 | val_loss: 6.285
0: Training epoch 0, iteration 32/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 32 | reduced_train_loss: 6.288 | train_step_timing in s: 0.1271 | consumed_samples: 264 | val_loss: 6.285
0: Training epoch 0, iteration 33/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 33 | reduced_train_loss: 6.28 | train_step_timing in s: 0.1268 | consumed_samples: 272 | val_loss: 6.285
0: Training epoch 0, iteration 34/99 | lr: 1.49e-05 | global_batch_size: 8 | global_step: 34 | reduced_train_loss: 6.3 | train_step_timing in s: 0.1268 | consumed_samples: 280 | val_loss: 6.285
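
For reference, the flatness is easy to quantify from the log above: val_loss is 6.285 at every validation, and reduced_train_loss only fluctuates within noise. A small, hypothetical helper for extracting the two series from a saved log file (the regex assumes the exact "key: value" format shown above):

```python
import re

# Hypothetical helper: pull reduced_train_loss / val_loss out of a saved training log
# in the "key: value" format shown above and report how much each metric actually moved.
METRIC_RE = re.compile(r"(reduced_train_loss|val_loss): ([0-9.]+)")

def metric_ranges(log_path: str) -> dict[str, tuple[float, float]]:
    series: dict[str, list[float]] = {}
    with open(log_path) as fh:
        for line in fh:
            for key, value in METRIC_RE.findall(line):
                series.setdefault(key, []).append(float(value))
    return {key: (min(vals), max(vals)) for key, vals in series.items()}

# For the excerpt above this prints something like
# {'reduced_train_loss': (6.26, 6.31), 'val_loss': (6.285, 6.285)},
# i.e. the validation loss never changes across 30+ training steps.
```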

Docker Image

No response

System Information

Environment Details:

  • OS: ubuntu
  • CPU: amd64

GPU Details:

  • GPU Model: H100

Additional Context

No response

Labels: Evo2, bug
