Commit 7192b5b: Fix geneformer training instability bug (#421)
See the wandb runs here:
https://wandb.ai/clara-discovery/geneformer_bionemo2_timing2
As the results below show, we can precisely control whether the grad norm
instability appears by setting or unsetting two NVTE env variables. Adding
those variables to our container was itself a recent change. Based on these
results we are unsetting them for now; doing so does not incur a
significant performance hit.
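For reference, a minimal sketch of the workaround (the `export` line only simulates what the container image currently does; the `unset` lines are the actual change):

```shell
# Simulate the container image's current defaults, which trigger the instability.
export NVTE_FUSED_ATTN=1 NVTE_FLASH_ATTN=0

# The fix: remove both overrides so Transformer Engine falls back to its own
# attention-backend selection before training starts.
unset NVTE_FUSED_ATTN NVTE_FLASH_ATTN

echo "NVTE_FUSED_ATTN=${NVTE_FUSED_ATTN:-unset} NVTE_FLASH_ATTN=${NVTE_FLASH_ATTN:-unset}"
# prints: NVTE_FUSED_ATTN=unset NVTE_FLASH_ATTN=unset
```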
## Old run where this was not an issue:
<img width="457" alt="Screenshot 2024-11-12 at 9 42 45 AM"
src="https://github.com/user-attachments/assets/7571ec4a-7bf1-4f86-901a-4dc983b53149">
## Representative new run where we see a spike in grad norm
<img width="730" alt="Screenshot 2024-11-12 at 9 43 25 AM"
src="https://github.com/user-attachments/assets/c9069d1d-3cc7-43e3-93d0-1a3ff07ecfe3">
## We can make this spike go away by unsetting `NVTE_FUSED_ATTN` and
`NVTE_FLASH_ATTN`
<img width="731" alt="Screenshot 2024-11-12 at 9 43 44 AM"
src="https://github.com/user-attachments/assets/3883383a-e943-4d26-a12a-956f7240bd45">
## We can introduce this spike on the old image that didn't have these
env variables by setting them
<img width="728" alt="Screenshot 2024-11-12 at 9 44 16 AM"
src="https://github.com/user-attachments/assets/d5daeb16-57be-4e8e-bde6-8b275bf53a46">
## Example longer/larger batch run that fails with these env variables
set
<img width="729" alt="Screenshot 2024-11-12 at 9 45 07 AM"
src="https://github.com/user-attachments/assets/00cdb307-1863-47e1-b93e-3227cbc7259b">
## We can stabilize this run by unsetting these env variables
<img width="729" alt="Screenshot 2024-11-12 at 9 45 30 AM"
src="https://github.com/user-attachments/assets/2cd370e3-5cdc-4385-9294-cdab068d6a8b">
The instability seems to be relatively recent, so this PR tests several
recent changes to see whether any of them is the cause.
- [x] Check whether the arange change is causing this
- [x] Check whether the grad buffer change (which should not be enabled) is
causing this
- [x] Check the bias fusions
- [x] Check the garbage collection callback
Find out when this worked:
- [x] PR 409 right before second perf change and dset change
- [x] PR 410 after first perf change, CLI refactor, and wandb fix
- [x] PR 404 right before new CLI
- [x] PR 362 (2 weeks ago) but restarting job before the gradients start
to increase
- [x] PR 362 (2 weeks ago)
- [x] **worked**
https://wandb.ai/clara-discovery/geneformer_bionemo2/runs/0sSIf3tl?nw=nwusernvjstjohn
The **worked** run uses
`bionemo2-pr312--136b1889fc390d9dad04f077b32b8fbecf50e25d`
- [x] bionemo2-pr312--136b1889fc390d9dad04f077b32b8fbecf50e25d but with
`NVTE_FUSED_ATTN=1` and `NVTE_FLASH_ATTN=0` set in my script **did not
work**
- [x] bionemo2-pr312--136b1889fc390d9dad04f077b32b8fbecf50e25d but with
`NVTE_FUSED_ATTN` and `NVTE_FLASH_ATTN` unset in my script **WORKED!!**
- [x] bionemo2-pr419--f2599382e4afaf061c9948628f3f72bb8e233fd6 (most
recent PR merged) but manually unsetting `NVTE_FUSED_ATTN=1` and
`NVTE_FLASH_ATTN=0`
Notes on differences between TOT and
`pr312--136b1889fc390d9dad04f077b32b8fbecf50e25d`
- `env` doesn't have `NVTE_FUSED*` env settings. Unclear if slurm script
adds them properly or not.
- `NVTE_FUSED_ATTN` and `NVTE_FLASH_ATTN` are set in
`bionemo2-pr373--db2fe9cc240b12bfaf045654fc5350a7b985c9de` for example.
- In slurm, `--export=ALL` is the default and passes all env variables
through to the job. Perhaps that is how they get in, so the run where I
added those env variables might fail if they are causing the issue.
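To illustrate the `--export=ALL` concern (a sketch; a child shell stands in for the slurm job here):

```shell
# With slurm's default --export=ALL, the job inherits the submitting shell's
# environment, so an NVTE override set locally leaks into the job:
export NVTE_FUSED_ATTN=1
sh -c 'echo "inherited: ${NVTE_FUSED_ATTN:-none}"'
# prints: inherited: 1

# Scrubbing the variable from the launch environment prevents the leak:
env -u NVTE_FUSED_ATTN sh -c 'echo "scrubbed: ${NVTE_FUSED_ATTN:-none}"'
# prints: scrubbed: none
```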
- The successful run used bs=32 vs 64. I'm running a test now that has the
NVTE* settings in the docker script but not in the image.
- This was a closed branch; maybe some key changes didn't make it to main.
- No `pip freeze` differences pop out that distinguish the branch that
passes from the set that fail.
- NOTE: See the experiments above around `NVTE_FUSED_ATTN=1` and
`NVTE_FLASH_ATTN=0`. I am pretty sure these settings are what cause the
training instability in geneformer: unsetting them makes the old PR work,
and setting them makes that old PR fail with the same explosion of
gradients.
- Currently I'm rerunning tests on a TOT branch but calling `unset` on
those variables in my script so that they are removed from the container
env prior to executing the script. If this fixes the TOT training curve I
will feel very confident that this is what's going on, and we can focus on
purging references to these variables from our docs, other than maybe
highlighting how they result in training instability.
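A quick audit along these lines (a sketch; the pattern only covers the two variables discussed in this PR) can confirm whether a given container or shell still carries the overrides:

```shell
# Report whether either NVTE attention override is present in the current
# environment; grep's exit status drives the branch.
if env | grep -Eq '^NVTE_(FUSED|FLASH)_ATTN='; then
  echo "NVTE attention overrides present"
else
  echo "NVTE attention overrides absent"
fi
```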
## File tree
7 files changed: +14 −24 lines
- docs/docs/user-guide/examples/bionemo-esm2
- scripts/protein/esm2
- sub-packages
- bionemo-geneformer/tests/bionemo/geneformer
- bionemo-llm/src/bionemo/llm/model/biobert
- bionemo-testing/src/bionemo/testing/harnesses