Commit 7192b5b

Fix geneformer training instability bug (#421)
See wandb runs here: https://wandb.ai/clara-discovery/geneformer_bionemo2_timing2

As the results below show, we can precisely control whether or not the grad norm instability appears by setting or unsetting the two NVTE environment variables. Adding the NVTE environment variables to our container was itself a recent change. Based on these results we are unsetting the variables for now; this change does not incur a significant performance hit.

## Old run where this was not an issue

<img width="457" alt="Screenshot 2024-11-12 at 9 42 45 AM" src="https://github.com/user-attachments/assets/7571ec4a-7bf1-4f86-901a-4dc983b53149">

## Representative new run where we see a spike in grad norm

<img width="730" alt="Screenshot 2024-11-12 at 9 43 25 AM" src="https://github.com/user-attachments/assets/c9069d1d-3cc7-43e3-93d0-1a3ff07ecfe3">

## We can make this spike go away by unsetting `NVTE_FUSED_ATTN` and `NVTE_FLASH_ATTN`

<img width="731" alt="Screenshot 2024-11-12 at 9 43 44 AM" src="https://github.com/user-attachments/assets/3883383a-e943-4d26-a12a-956f7240bd45">

## We can introduce this spike on the old image, which didn't have these env variables, by setting them

<img width="728" alt="Screenshot 2024-11-12 at 9 44 16 AM" src="https://github.com/user-attachments/assets/d5daeb16-57be-4e8e-bde6-8b275bf53a46">

## Example longer/larger-batch run that fails with these env variables set

<img width="729" alt="Screenshot 2024-11-12 at 9 45 07 AM" src="https://github.com/user-attachments/assets/00cdb307-1863-47e1-b93e-3227cbc7259b">

## We can stabilize this run by unsetting these env variables

<img width="729" alt="Screenshot 2024-11-12 at 9 45 30 AM" src="https://github.com/user-attachments/assets/2cd370e3-5cdc-4385-9294-cdab068d6a8b">

The instability seems to be relatively recent, so this PR tested several recent changes to see whether any of them caused it:

- [x] Check if the arange change is causing this
- [x] Check if the grad buffer change (should not be enabled) is causing this
- [x] bias fusions
- [x] garbage collection callback

Find out when this worked:

- [x] PR 409, right before the second perf change and the dset change
- [x] PR 410, after the first perf change, the CLI refactor, and the wandb fix
- [x] PR 404, right before the new CLI
- [x] PR 362 (2 weeks ago), but restarting the job before the gradients start to increase
- [x] PR 362 (2 weeks ago): **worked**. https://wandb.ai/clara-discovery/geneformer_bionemo2/runs/0sSIf3tl?nw=nwusernvjstjohn uses `bionemo2-pr312--136b1889fc390d9dad04f077b32b8fbecf50e25d`
- [x] `bionemo2-pr312--136b1889fc390d9dad04f077b32b8fbecf50e25d` but with `NVTE_FUSED_ATTN=1` and `NVTE_FLASH_ATTN=0` set in my script: **did not work**
- [x] `bionemo2-pr312--136b1889fc390d9dad04f077b32b8fbecf50e25d` but with `NVTE_FUSED_ATTN` and `NVTE_FLASH_ATTN` `unset` in my script: **WORKED!**
- [x] `bionemo2-pr419--f2599382e4afaf061c9948628f3f72bb8e233fd6` (the most recently merged PR) but manually unsetting `NVTE_FUSED_ATTN=1` and `NVTE_FLASH_ATTN=0`

Notes on differences between TOT and `pr312--136b1889fc390d9dad04f077b32b8fbecf50e25d`:

- `env` doesn't have the `NVTE_FUSED*` settings. It is unclear whether the slurm script adds them properly or not.
- `NVTE_FUSED_ATTN` and `NVTE_FLASH_ATTN` are set in `bionemo2-pr373--db2fe9cc240b12bfaf045654fc5350a7b985c9de`, for example.
- In slurm, `--export=ALL` is the default and passes through all env variables. Perhaps that is how they propagate, so the run where I added those env variables may fail if they are causing the issue.
- The successful run used bs=32 vs 64. I'm running a test now that has the NVTE* settings in the docker script but not in the image.
- This was a closed branch; maybe some key changes didn't make it to main.
- No `pip freeze` differences pop out that distinguish the branch that passes from the set that fail.
- NOTE: see the experiments above around `NVTE_FUSED_ATTN=1` and `NVTE_FLASH_ATTN=0`.

I am pretty sure these settings are what cause the training instability in geneformer: unsetting them fixes the old PR, and setting them makes that old PR fail with the same explosion of gradients. Currently I'm rerunning tests on a TOT branch while calling `unset` on those variables in my script, so that they are removed from the container env before the script executes. If this fixes the TOT training curve, I will feel very confident that this is what's going on, and we can focus on purging references to these variables from our docs, other than perhaps highlighting how they cause training instability.
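The fix amounts to removing the two variables from the environment before training starts. A rough illustration of the `unset` approach described above (the launcher function and names here are hypothetical sketches, not from the commit):

```python
import os
import subprocess

# The two transformer engine variables implicated in the grad norm spikes.
NVTE_VARS = ("NVTE_FUSED_ATTN", "NVTE_FLASH_ATTN")


def strip_nvte_overrides(environ: dict) -> dict:
    """Return a copy of the environment with the NVTE attention overrides removed."""
    env = dict(environ)
    for var in NVTE_VARS:
        env.pop(var, None)  # remove if present, no-op otherwise
    return env


def launch_training(cmd: list) -> subprocess.CompletedProcess:
    """Run a training command with the overrides stripped, mirroring `unset` in a slurm script."""
    return subprocess.run(cmd, env=strip_nvte_overrides(os.environ), check=True)
```

Because slurm's `--export=ALL` default forwards the whole environment, stripping the variables in the launching process (rather than in the job script) also keeps them out of the submitted job.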
1 parent 4ba3595 commit 7192b5b

File tree

7 files changed: +14 additions, -24 deletions


Dockerfile

Lines changed: 6 additions & 2 deletions
```diff
@@ -166,7 +166,9 @@ RUN <<EOF
 EOF
 
 # Transformer engine attention defaults
-ENV NVTE_FUSED_ATTN=1 NVTE_FLASH_ATTN=0
+# FIXME the following result in unstable training curves even if they are faster
+# see https://github.com/NVIDIA/bionemo-framework/pull/421
+#ENV NVTE_FUSED_ATTN=1 NVTE_FLASH_ATTN=0
 
 FROM dev AS development
 
@@ -207,4 +209,6 @@ RUN chmod 777 -R /workspace/bionemo2/
 
 # Transformer engine attention defaults
 # We have to declare this again because the devcontainer splits from the release image's base.
-ENV NVTE_FUSED_ATTN=1 NVTE_FLASH_ATTN=0
+# FIXME the following results in unstable training curves even if faster.
+# See https://github.com/NVIDIA/bionemo-framework/pull/421
+#ENV NVTE_FUSED_ATTN=1 NVTE_FLASH_ATTN=0
```
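With the `ENV` lines commented out, a quick sanity check inside the container can confirm nothing else re-introduces the overrides (an illustrative helper, not part of the commit):

```python
import os

NVTE_VARS = ("NVTE_FUSED_ATTN", "NVTE_FLASH_ATTN")


def nvte_overrides_present(environ=None) -> list:
    """Return whichever NVTE attention overrides are still set, if any."""
    environ = os.environ if environ is None else environ
    return [v for v in NVTE_VARS if v in environ]


if __name__ == "__main__":
    leaked = nvte_overrides_present()
    if leaked:
        print(f"warning: {leaked} still set; see PR #421 for the instability they caused")
    else:
        print("NVTE attention overrides are unset; transformer engine chooses its own backend")
```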

README.md

Lines changed: 0 additions & 6 deletions
````diff
@@ -186,9 +186,6 @@ export MY_DATA_SOURCE="pbss"
 
 ```bash
 # The fastest transformer engine environment variables in testing were the following two
-export NVTE_FUSED_ATTN=1
-export NVTE_FLASH_ATTN=0
-
 TEST_DATA_DIR=$(download_bionemo_data esm2/testdata_esm2_pretrain:2.0 --source $MY_DATA_SOURCE); \
 ESM2_650M_CKPT=$(download_bionemo_data esm2/650m:2.0 --source $MY_DATA_SOURCE); \
 python \
 
@@ -248,9 +245,6 @@ and DataModule types.
 > ⚠️ **Warning:** This setup does NO configuration of Weights and Biases. Edit your config JSON and populate it with your WandB details.
 
 ```
-export NVTE_FUSED_ATTN=1
-export NVTE_FLASH_ATTN=0
-
 bionemo-esm2-train \
 --data-config-t bionemo.esm2.run.config_models.ESM2DataConfig \
 --model-config-t bionemo.esm2.run.config_models.ExposedESM2PretrainConfig \
````

docs/docs/user-guide/examples/bionemo-esm2/pretrain.md

Lines changed: 0 additions & 3 deletions
````diff
@@ -280,9 +280,6 @@ llm.train(
 Or simply call `esm2_pretrain.py` directly.
 ```bash
 # Enable fused attention in transformer engine for speed-up
-export NVTE_FUSED_ATTN=1
-export NVTE_FLASH_ATTN=0
-
 DATA_DIR=$(download_bionemo_data esm2/testdata_esm2_pretrain:2.0 --source ngc)
 
 python scripts/protein/esm2/esm2_pretrain.py \
````

scripts/protein/esm2/test_esm2_pretrain.py

Lines changed: 0 additions & 7 deletions
```diff
@@ -90,8 +90,6 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
     result_dir = Path(tmpdir.mkdir("results"))
 
     with megatron_parallel_state_utils.distributed_model_parallel_state():
-        monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
-        monkeypatch.setenv("NVTE_FLASH_ATTN", "0")
         main(
             train_cluster_path=train_cluster_path,
             train_database_path=dummy_protein_dataset,
 
@@ -159,8 +157,6 @@ def test_val_dataloader_in_main_runs_with_limit_val_batches(
     result_dir = Path(tmpdir.mkdir("results"))
 
     with megatron_parallel_state_utils.distributed_model_parallel_state():
-        monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
-        monkeypatch.setenv("NVTE_FLASH_ATTN", "0")
         main(
             train_cluster_path=train_cluster_path,
             train_database_path=dummy_protein_dataset,
 
@@ -239,9 +235,6 @@ def test_pretrain_cli(tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inp
     # a local copy of the environment
     env = dict(**os.environ)
     env["MASTER_PORT"] = str(open_port)
-    env["NVTE_FUSED_ATTN"] = "1"
-    env["NVTE_FLASH_ATTN"] = "0"
-
     cmd = shlex.split(cmd_str)
     result = subprocess.run(
         cmd,
```

sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_stop_and_go.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -97,6 +97,7 @@ class TestGeneformerStopAndGo(stop_and_go.StopAndGoHarness):
     limit_val_batches: int = 2
     lr: float = 1e-4
     precision: Literal["16-mixed", "bf16-mixed", "32"] = MODEL_PRECISION
+    train_val_output_atol: float = 2e-2
 
     @override
     @classmethod
```

sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -525,8 +525,7 @@ def configure_model(self, tokenizer: AutoTokenizer) -> MegatronBioBertModelType:
             self.num_layers // p_size
         ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages."
 
-        # The local specs all require the standard full attention mask. For transformer engine only the NVTE_FLASH_ATTN=0
-        # option requires this full attention mask.
+        # The local specs all require the standard full attention mask.
         use_full_attention_mask: bool = "transformer_engine" not in self.biobert_spec_option
         do_next_sentence = False
         if self.model_cls is None:
```

sub-packages/bionemo-testing/src/bionemo/testing/harnesses/stop_and_go.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -106,6 +106,8 @@ class StopAndGoHarness(ABC):
     limit_val_batches: int
     lr: float = 1e-4
     precision: Literal["16-mixed", "bf16-mixed", "32"]
+    train_val_output_atol: float = 1e-3
+    other_output_atol: float = 1e-4
 
     # class variables that will be setup in setUpClass
     tempdir: tempfile.TemporaryDirectory
 
@@ -336,9 +338,9 @@ def test_stop_and_go_consistency(self, callback_type):
         assert interrupted_callback.data, f"No data found for {callback_type}"
 
         if callback_type == testing_callbacks.TrainOutputCallback:
-            atol = 1e-3
+            atol = self.train_val_output_atol
         else:
-            atol = 1e-4
+            atol = self.other_output_atol
 
         recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data, atol=atol)
 
@@ -388,8 +390,8 @@ def test_stop_and_go_consistency_with_uneven_validation_sizes(self, callback_typ
         interrupted_data = interrupted_callback.data[-len(continuous_callback.data) :]
 
         if callback_type == testing_callbacks.ValidOutputCallback:
-            atol = 1e-3
+            atol = self.train_val_output_atol
         else:
-            atol = 1e-4
+            atol = self.other_output_atol
 
         recursive_assert_approx_equal(interrupted_data, continuous_callback.data, atol=atol)
```
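The harness change turns the hard-coded tolerances into class attributes so subclasses can loosen them; the geneformer test overrides `train_val_output_atol` to 2e-2. A stripped-down sketch of that pattern (class names simplified, not the real harness):

```python
class Harness:
    # Defaults matching the new class attributes on StopAndGoHarness.
    train_val_output_atol: float = 1e-3
    other_output_atol: float = 1e-4

    def atol_for(self, is_train_val_output: bool) -> float:
        """Pick the tolerance the consistency checks use for a given callback type."""
        return self.train_val_output_atol if is_train_val_output else self.other_output_atol


class GeneformerHarness(Harness):
    # Looser train/val tolerance, as in TestGeneformerStopAndGo.
    train_val_output_atol: float = 2e-2
```

Because the lookup goes through `self`, each subclass only overrides the attribute it needs; the other tolerance falls back to the base class default.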
