Skip to content

Commit 1a6f2b4

Browse files
committed
Add weight tying logic to LM head, since Lingua does not tie weights.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent 45649e3 commit 1a6f2b4

File tree

4 files changed

+31
-13
lines changed

4 files changed

+31
-13
lines changed

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -465,10 +465,16 @@ def __init__(
465465
tp_size=config.tp_size,
466466
)
467467
if config.tensor_parallel:
468-
# If using tensor parallelism, the head weights have already been tied
469-
# to the embedding weights. Just set the tensor parallel group for TE.
470-
# No parameter quantization either, so no need for weight_mesh.
471-
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
468+
if config.tie_word_embeddings:
469+
# Head weights have already been tied to the embedding weights.
470+
# Just set the tensor parallel group for TE.
471+
# No parameter quantization either, so no need for weight_mesh.
472+
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
473+
else:
474+
# Head weights are not tied to the embedding weights. Need to
475+
# wrap the LM head weight as a DTensor with TE.
476+
# No parameter quantization either, so no need for weight_mesh.
477+
self.lm_head.set_device_mesh(tp_mesh=self.tp_mesh)
472478

473479
# Initialize weights and apply final processing. Ties weights.
474480
self.post_init()

bionemo-recipes/recipes/llama3_native_te/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ bionemo-framework repository. You can download a zipped directory of this folder
1818

1919
| Model | BF16 | FP8<sup>[1]</sup> | THD Input Format | FP8 with THD Input Format | MXFP8<sup>[2]</sup> | Context Parallelism | Tensor Parallelism |
2020
| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ |
21-
| [Llama 3](../../models/llama3/README.md) ||||||| 🚧 |
21+
| [Llama 3](../../models/llama3/README.md) ||||||| |
2222

2323
✅: Supported <br/>
2424
🚧: Under development <br/>

bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -471,10 +471,16 @@ def __init__(
471471
tp_size=config.tp_size,
472472
)
473473
if config.tensor_parallel:
474-
# If using tensor parallelism, the head weights have already been tied
475-
# to the embedding weights. Just set the tensor parallel group for TE.
476-
# No parameter quantization either, so no need for weight_mesh.
477-
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
474+
if config.tie_word_embeddings:
475+
# Head weights have already been tied to the embedding weights.
476+
# Just set the tensor parallel group for TE.
477+
# No parameter quantization either, so no need for weight_mesh.
478+
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
479+
else:
480+
# Head weights are not tied to the embedding weights. Need to
481+
# wrap the LM head weight as a DTensor with TE.
482+
# No parameter quantization either, so no need for weight_mesh.
483+
self.lm_head.set_device_mesh(tp_mesh=self.tp_mesh)
478484

479485
# Initialize weights and apply final processing. Ties weights.
480486
self.post_init()

bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -471,10 +471,16 @@ def __init__(
471471
tp_size=config.tp_size,
472472
)
473473
if config.tensor_parallel:
474-
# If using tensor parallelism, the head weights have already been tied
475-
# to the embedding weights. Just set the tensor parallel group for TE.
476-
# No parameter quantization either, so no need for weight_mesh.
477-
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
474+
if config.tie_word_embeddings:
475+
# Head weights have already been tied to the embedding weights.
476+
# Just set the tensor parallel group for TE.
477+
# No parameter quantization either, so no need for weight_mesh.
478+
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
479+
else:
480+
# Head weights are not tied to the embedding weights. Need to
481+
# wrap the LM head weight as a DTensor with TE.
482+
# No parameter quantization either, so no need for weight_mesh.
483+
self.lm_head.set_device_mesh(tp_mesh=self.tp_mesh)
478484

479485
# Initialize weights and apply final processing. Ties weights.
480486
self.post_init()

0 commit comments

Comments
 (0)