File tree (4 files changed: +31 −13 lines changed)
opengenome2_llama_native_te (4 files changed: +31 −13 lines changed)
lines changed Original file line number Diff line number Diff line change @@ -465,10 +465,16 @@ def __init__(
465465 tp_size = config .tp_size ,
466466 )
467467 if config .tensor_parallel :
468- # If using tensor parallelism, the head weights have already been tied
469- # to the embedding weights. Just set the tensor parallel group for TE.
470- # No parameter quantization either, so no need for weight_mesh.
471- self .lm_head .set_tensor_parallel_group (self .tp_mesh .get_group ())
468+ if config .tie_word_embeddings :
469+ # Head weights have already been tied to the embedding weights.
470+ # Just set the tensor parallel group for TE.
471+ # No parameter quantization either, so no need for weight_mesh.
472+ self .lm_head .set_tensor_parallel_group (self .tp_mesh .get_group ())
473+ else :
474+ # Head weights are not tied to the embedding weights. Need to
475+ # wrap the LM head weight as a DTensor with TE.
476+ # No parameter quantization either, so no need for weight_mesh.
477+ self .lm_head .set_device_mesh (tp_mesh = self .tp_mesh )
472478
473479 # Initialize weights and apply final processing. Ties weights.
474480 self .post_init ()
Original file line number Diff line number Diff line change @@ -18,7 +18,7 @@ bionemo-framework repository. You can download a zipped directory of this folder
1818
19 19 | Model | BF16 | FP8<sup>[1]</sup> | THD Input Format | FP8 with THD Input Format | MXFP8<sup>[2]</sup> | Context Parallelism | Tensor Parallelism |
20 20 | ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ |
21 -  | [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 |
21 +  | [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
2222
23 23 ✅: Supported <br />
24 24 🚧: Under development <br />
Original file line number Diff line number Diff line change @@ -471,10 +471,16 @@ def __init__(
471471 tp_size = config .tp_size ,
472472 )
473473 if config .tensor_parallel :
474- # If using tensor parallelism, the head weights have already been tied
475- # to the embedding weights. Just set the tensor parallel group for TE.
476- # No parameter quantization either, so no need for weight_mesh.
477- self .lm_head .set_tensor_parallel_group (self .tp_mesh .get_group ())
474+ if config .tie_word_embeddings :
475+ # Head weights have already been tied to the embedding weights.
476+ # Just set the tensor parallel group for TE.
477+ # No parameter quantization either, so no need for weight_mesh.
478+ self .lm_head .set_tensor_parallel_group (self .tp_mesh .get_group ())
479+ else :
480+ # Head weights are not tied to the embedding weights. Need to
481+ # wrap the LM head weight as a DTensor with TE.
482+ # No parameter quantization either, so no need for weight_mesh.
483+ self .lm_head .set_device_mesh (tp_mesh = self .tp_mesh )
478484
479485 # Initialize weights and apply final processing. Ties weights.
480486 self .post_init ()
Original file line number Diff line number Diff line change @@ -471,10 +471,16 @@ def __init__(
471471 tp_size = config .tp_size ,
472472 )
473473 if config .tensor_parallel :
474- # If using tensor parallelism, the head weights have already been tied
475- # to the embedding weights. Just set the tensor parallel group for TE.
476- # No parameter quantization either, so no need for weight_mesh.
477- self .lm_head .set_tensor_parallel_group (self .tp_mesh .get_group ())
474+ if config .tie_word_embeddings :
475+ # Head weights have already been tied to the embedding weights.
476+ # Just set the tensor parallel group for TE.
477+ # No parameter quantization either, so no need for weight_mesh.
478+ self .lm_head .set_tensor_parallel_group (self .tp_mesh .get_group ())
479+ else :
480+ # Head weights are not tied to the embedding weights. Need to
481+ # wrap the LM head weight as a DTensor with TE.
482+ # No parameter quantization either, so no need for weight_mesh.
483+ self .lm_head .set_device_mesh (tp_mesh = self .tp_mesh )
478484
479485 # Initialize weights and apply final processing. Ties weights.
480486 self .post_init ()
You can’t perform that action at this time.
0 commit comments