
ai code review and fix #1480

Merged
pstjohn merged 1 commit into NVIDIA:main from pstjohn:pstjohn/claude-review
Feb 27, 2026

Conversation

@pstjohn
Collaborator

@pstjohn pstjohn commented Feb 21, 2026

Miscellaneous semi-automated fixes that add documentation and additional tests throughout the recipes directories.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Mixtral model support with TransformerEngine optimizations and comprehensive documentation
    • Introduced context-parallel tensor processing helpers across model implementations
    • Added token dropout support for different input formats (BSHD and THD)
    • New test utilities for checkpoint pruning and scheduler validation
  • Refactoring

    • Updated state transformation API across convert modules for consistency
    • Simplified checkpoint and dataset utilities
    • Refactored collators to use modular helper functions
  • Documentation

    • Expanded module docstrings describing state dict transformation system
    • Added documentation for attention input formats across models
    • Updated README references and configuration documentation
  • Tests

    • Added Mixtral export validation tests
    • New checkpoint pruning and scheduler tests with comprehensive edge-case coverage
    • Updated test base classes for consistency across model implementations

@copy-pr-bot

copy-pr-bot bot commented Feb 21, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Feb 21, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.




Comment @coderabbitai help to get the list of available commands and usage tips.

@pstjohn pstjohn force-pushed the pstjohn/claude-review branch from 14b1640 to 773d581 on February 24, 2026 at 22:09
@pstjohn pstjohn marked this pull request as ready for review on February 24, 2026 at 22:10
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (1)

590-590: ⚠️ Potential issue | 🟠 Major

torch.load called without weights_only=True — arbitrary code execution risk.

Loading a pickle-backed .pt file without weights_only=True allows a crafted or corrupted checkpoint to execute arbitrary Python during deserialization. The llama3_native_te/checkpoint.py equivalent (line 444) already passes weights_only=True; this file should match it.

🛡️ Proposed fix
-    dataloader_state = torch.load(dataloader_path)
+    dataloader_state = torch.load(dataloader_path, weights_only=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/checkpoint.py` at line 590, The call
dataloader_state = torch.load(dataloader_path) in checkpoint.py is unsafe;
update the torch.load invocation to pass weights_only=True (matching the
llama3_native_te/checkpoint.py usage) so it deserializes only tensor data and
avoids executing arbitrary pickle code—locate the dataloader_state assignment in
the file and add the weights_only=True argument to torch.load.
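The `weights_only=True` flag restricts torch's unpickler to an allowlist of tensor-related types. The mechanism behind that restriction can be sketched with the stdlib `pickle` module alone — this is an illustration of the idea, not torch's actual implementation:

```python
import io
import pickle


class RestrictedUnpickler(pickle.Unpickler):
    """Refuse to resolve unapproved globals, so a crafted pickle cannot
    trigger arbitrary code on load -- the idea behind weights_only=True."""

    ALLOWED = {("builtins", "dict"), ("builtins", "list")}

    def find_class(self, module, name):
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked global: {module}.{name}")


# Plain data (no globals needed) round-trips fine.
payload = pickle.dumps({"step": 590, "epoch": 3})
state = RestrictedUnpickler(io.BytesIO(payload)).load()
assert state["step"] == 590

# A pickle that smuggles in os.system is rejected instead of executed.
evil = b"cos\nsystem\n(S'echo pwned'\ntR."
try:
    RestrictedUnpickler(io.BytesIO(evil)).load()
except pickle.UnpicklingError:
    pass  # blocked, as intended
```

A checkpoint that only needs tensors and plain containers loses nothing from the restriction, which is why matching the llama3 recipe's `weights_only=True` usage is a low-cost fix.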
♻️ Duplicate comments (3)
bionemo-recipes/models/llama3/tests/common/__init__.py (1)

21-21: Same docstring formatting issue as in the ESM2 __init__.py.

Two bullet items collapsed onto one line. See the fix proposed in the ESM2 review.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/llama3/tests/common/__init__.py` at line 21, The
module docstring in bionemo-recipes/models/llama3/tests/common/__init__.py has
two bullet items merged onto one line; update the top-level docstring so each
bullet is on its own line (separate the entries for BaseModelTest and
TestTolerances into distinct list items) and ensure the docstring follows the
same multiline bullet formatting used in the ESM2 __init__.py example.
bionemo-recipes/models/llama3/collator.py (1)

733-868: Duplicate of helpers already reviewed in models/esm2/collator.py.

These are identical implementations. See the code duplication comment on the ESM2 collator review — consider a shared utility module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/llama3/collator.py` around lines 733 - 868, The
functions _find_seq_dim, _process_tensor_thd, and _process_tensor_bshd are
duplicates of helpers in models/esm2/collator.py; refactor by extracting these
helpers into a shared utility module (e.g., a new module like
models/shared/collators.py or similar) and replace the local definitions with
imports and usage of the shared functions; update the current file to import
_find_seq_dim, _process_tensor_thd, and _process_tensor_bshd from that shared
module and remove the duplicate definitions here, ensuring any referenced
symbols (seq_len, slice_sizes, cu_seqlens_padded, cp_rank, total_slices,
cp_world_size) match the shared helper signatures.
bionemo-recipes/models/mixtral/tests/common/__init__.py (1)

21-21: Same docstring formatting issue as in the ESM2 and Llama3 __init__.py files.

Two bullet items collapsed onto one line. See the fix proposed in the ESM2 review.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/tests/common/__init__.py` at line 21, The
module docstring has two bullet items collapsed onto one line; update the
module-level docstring in __init__.py so each item is on its own line (use a
newline and proper bullet prefix for "BaseModelTest: Base test class with all
common test methods" and "TestTolerances: Dataclass for model-specific numerical
tolerances")—locate the docstring near the top of the file and adjust the
formatting to match the fixed style used in the ESM2/Llama3 __init__.py files.
🧹 Nitpick comments (8)
bionemo-recipes/models/mixtral/README.md (2)

11-19: Table cell padding is not mdformat-compliant.

mdformat normalises table cells to use minimal spacing (single space padding), but the current table uses wide right-padding to visually align columns. Running mdformat will reformat these cells, producing diff noise in future PRs.

Run mdformat bionemo-recipes/models/mixtral/README.md to normalise formatting. As per coding guidelines: "Use mdformat for Markdown formatting."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/README.md` around lines 11 - 19, The Markdown
table in README.md uses wide cell padding instead of mdformat's minimal
single-space padding; run mdformat on the file (e.g., mdformat
bionemo-recipes/models/mixtral/README.md) or manually reduce each table cell to
single-space padding so the table rows (the lines containing "Feature | Support"
and the subsequent pipe-separated rows like "**FP8** | ✅ Supported...") conform
to mdformat normalization and avoid future diff noise.

87-101: export.py is not mentioned in the Developer Guide.

Per project requirements, each model recipe must ship an export.py for Hugging Face Hub bundling. The README does not reference it, leaving users without guidance on how to package and publish the TE model. Consider adding a brief "Exporting to Hugging Face Hub" subsection that calls out export.py and its usage.

Based on learnings: "Models in bionemo-recipes/models/ must include: … export script (export.py) for Hugging Face Hub bundling."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/README.md` around lines 87 - 101, Add a new
"Exporting to Hugging Face Hub" subsection to the Developer Guide that documents
the required export.py script: state that the model directory must include
export.py, explain how to run it to create the HF bundle (e.g., run export.py
from the model directory or via python export.py with any required args), note
any dependencies or env vars needed for HF upload, and link to
recipes_local_test.py as the local test step before publishing; reference the
filename export.py and the test runner recipes_local_test.py so maintainers can
locate and update the script.
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (1)

101-108: Consider adding reduce_dtype=torch.float32 to the else branch for gradient stability.

The FSDP2 MixedPrecisionPolicy signature is MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), confirming output_dtype is a valid field — the if branch is correct.

For the else branch (pure BF16, use_fp32_master_weights=False), MixedPrecisionPolicy() defaults leave reduce_dtype=None, meaning gradient all-reduces also happen in BF16. Gradients can vary significantly from rank to rank, and reducing in float32 can be critical for numerics. The if branch already does this correctly with reduce_dtype=torch.float32; omitting it in the else branch may cause training instability, especially for larger models.

💡 Suggested improvement for the `else` branch
     else:
-        mp_policy = MixedPrecisionPolicy()
+        mp_policy = MixedPrecisionPolicy(reduce_dtype=torch.float32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py` around lines 101 -
108, The else branch leaves MixedPrecisionPolicy.reduce_dtype as None so
gradient all-reduces happen in BF16, which can destabilize training; update the
else branch that sets mp_policy (when args.use_fp32_master_weights is False) to
instantiate MixedPrecisionPolicy with reduce_dtype=torch.float32 (e.g.,
MixedPrecisionPolicy(reduce_dtype=torch.float32, param_dtype=torch.bfloat16,
output_dtype=torch.bfloat16 or just set reduce_dtype alongside the default
call)) so gradient reductions occur in FP32 while keeping the rest of the BF16
policy.
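The numeric argument above can be illustrated with a stdlib-only sketch: emulating bfloat16 rounding (float32 with the mantissa truncated to 7 bits) shows how an accumulating reduction stalls once the running sum's spacing exceeds the addends, while a higher-precision accumulator (a stand-in for `reduce_dtype=torch.float32`) does not. The values are illustrative, not taken from the recipe:

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to bfloat16 precision (8 exponent bits, 7 mantissa bits)
    by keeping the top 16 bits of the float32 encoding, round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


grad = 0.001  # one small per-step gradient contribution
n = 4096      # number of contributions to reduce

# Accumulate in emulated bf16: every partial sum is rounded.
bf16_sum = 0.0
for _ in range(n):
    bf16_sum = to_bf16(bf16_sum + grad)

# Accumulate at higher precision.
fp_sum = sum(grad for _ in range(n))

print(f"bf16 accumulation: {bf16_sum:.4f}")  # stalls near 0.5
print(f"fp accumulation:   {fp_sum:.4f}")    # close to 4.096
```

Once the bf16 running sum reaches 0.5, its spacing (2^-8) is wide enough that each 0.001 addend rounds away, so the rest of the mass is lost — the kind of error an FP32 reduction avoids.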
bionemo-recipes/models/mixtral/convert.py (1)

65-74: Add Args and Returns sections to complete the Google-style docstring.

The updated body text is well-written, but the docstring omits the Args and Returns sections that are present in every other function in this file, leaving num_experts and the return type undocumented.

📝 Proposed fix
 def _make_merge_experts_fn(num_experts: int):
     """Create a merge function with the correct number of named parameters.
 
     The state.py transform system maps function parameter names to source dict keys by inspecting
     the function signature. When ``source_key`` is a tuple, it pairs each tuple element with the
     corresponding named parameter via ``{param: source_key[i]}``. This means ``*args`` style
     parameters do not work -- the system cannot map positional varargs to specific source keys.
 
     Since the number of experts is dynamic (varies per model config), we use ``exec()`` to generate
     a function with exactly ``num_experts`` named parameters (weight0, weight1, ..., weightN-1).
+
+    Args:
+        num_experts: Number of experts; determines the count of named parameters in the generated function.
+
+    Returns:
+        A callable ``merge_experts(weight0, weight1, ..., weightN-1)`` that stacks its inputs along a new
+        leading dimension using ``torch.stack``.
     """

As per coding guidelines, "Ensure all Python files follow Google-style docstrings (pydocstyle convention)."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/convert.py` around lines 65 - 74, The
docstring for the function create_merge_fn (the factory that generates a merge
function with num_experts named parameters) is missing Google-style "Args" and
"Returns" sections; update the docstring to add an Args section documenting
num_experts (type: int, meaning the number of expert-weight parameters to
generate) and any other parameters, and add a Returns section describing the
returned callable (e.g., a function taking weight0..weightN-1 and returning the
merged result, include its type/signature). Keep wording consistent with other
functions in the file and follow the existing Google-style formatting used
elsewhere.
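The `exec()` technique the docstring describes can be sketched in isolation. This hypothetical stand-in returns a list instead of calling `torch.stack`, but shows how generating the source yields a function whose signature has exactly `num_experts` named parameters — which is what signature-based key mapping requires:

```python
import inspect


def make_merge_fn(num_experts: int):
    """Generate merge(weight0, ..., weightN-1) with real named parameters,
    so signature inspection can map each parameter to a source key."""
    params = ", ".join(f"weight{i}" for i in range(num_experts))
    source = f"def merge({params}):\n    return [{params}]\n"
    namespace: dict = {}
    exec(source, namespace)  # source is built from an int only, so this is safe
    return namespace["merge"]


merge = make_merge_fn(3)
assert list(inspect.signature(merge).parameters) == ["weight0", "weight1", "weight2"]
assert merge(10, 20, 30) == [10, 20, 30]
```

A `*args` signature would be shorter, but `inspect.signature` would then report a single varargs parameter, defeating the per-key mapping — hence the generated named parameters.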
bionemo-recipes/recipes/esm2_native_te/tests/test_checkpoint_pruning.py (1)

89-103: Good coverage — consider adding a negative save_every_n_steps edge-case assertion.

The current suite covers the documented contract thoroughly. One untested, albeit well-defined, edge case is save_every_n_steps < 0: the existing guard save_every_n_steps > 0 makes it return False, but an explicit assertion would document the intended behaviour and guard against future regressions if the guard is ever changed.

✏️ Suggested addition
     # save_every_n_steps=0 should never save
     assert should_save_checkpoint(step=10, save_every_n_steps=0) is False
+
+    # Negative save_every_n_steps should never save
+    assert should_save_checkpoint(step=10, save_every_n_steps=-1) is False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/tests/test_checkpoint_pruning.py`
around lines 89 - 103, Add an explicit test assertion for the edge case
save_every_n_steps < 0 in the test_should_save_checkpoint function: verify that
should_save_checkpoint(step=10, save_every_n_steps=-1) returns False to document
and lock in the intended behavior; update the same test (or add a new one
nearby) referencing the should_save_checkpoint function so future changes to the
guard save_every_n_steps > 0 will be caught by CI.
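A minimal sketch of the guard under test, with the suggested negative-interval assertion included (the real function's signature and behaviour are assumed from the review comment):

```python
def should_save_checkpoint(step: int, save_every_n_steps: int) -> bool:
    # Save only on positive intervals, at exact multiples of the interval.
    # The `> 0` guard also covers negative values: in Python, 10 % -1 == 0,
    # so a bare modulus check alone would wrongly return True.
    return save_every_n_steps > 0 and step % save_every_n_steps == 0


assert should_save_checkpoint(step=10, save_every_n_steps=5) is True
assert should_save_checkpoint(step=10, save_every_n_steps=0) is False
assert should_save_checkpoint(step=10, save_every_n_steps=-1) is False
```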
bionemo-recipes/models/esm2/collator.py (2)

845-849: Minor: _process_tensor_bshd divisibility check is incomplete but safe in practice.

The error message says the sequence length "must be divisible by" total_chunks, but the check (chunk_size == 0) only catches when seq_len < total_chunks. If seq_len % total_chunks != 0, the remainder is silently dropped. This is safe because the upstream pad_thd_sequences_for_cp guarantees divisibility, but the error message is misleading. Consider adding a strict check:

🔧 Optional stricter validation
-    if chunk_size == 0:
+    if seq_len % total_chunks != 0:
         raise ValueError(
             f"Sequence length {seq_len} must be divisible by {total_chunks} "
             f"(2 * cp_world_size) for BSHD context parallelism"
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/esm2/collator.py` around lines 845 - 849, In
_process_tensor_bshd, replace the incomplete divisibility check that only tests
chunk_size == 0 with a strict modulus check (if seq_len % total_chunks != 0) and
raise a ValueError including seq_len and total_chunks in the message; reference
the relationship to pad_thd_sequences_for_cp in the message or a comment to
clarify why this should normally not trigger. This ensures remainder cases are
caught instead of silently dropping tokens and makes the error message accurate.

733-868: Significant code duplication across three model collators: extract shared helpers to reduce maintenance burden.

_find_seq_dim, _process_tensor_thd, and _process_tensor_bshd are duplicated verbatim in bionemo-recipes/models/esm2/collator.py, bionemo-recipes/models/llama3/collator.py, and bionemo-recipes/models/mixtral/collator.py (~113 lines total). While recipe duplication is justified by the self-containment guideline, the duplication across three model collators violates DRY principles. Consider extracting these helpers into a shared utility module under bionemo-recipes/models/ (e.g., bionemo-recipes/models/common/cp_utils.py) and importing from there in all three model collators.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/esm2/collator.py` around lines 733 - 868, Extract the
duplicated helpers _find_seq_dim, _process_tensor_thd, and _process_tensor_bshd
into a single shared utility module (e.g., cp_utils) and replace the verbatim
copies in each collator with imports from that module; specifically, move the
three functions as-is into the new module (preserving signatures and torch
usage), update the collator files to import _find_seq_dim, _process_tensor_thd,
and _process_tensor_bshd, and ensure any device/typing references still resolve
(add necessary imports like torch and typing in the new module) so behavior and
exceptions remain unchanged.
bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py (1)

241-290: Consider reusing _create_inference_params to avoid duplication.

The TE beam-search test re-implements the same KV-cache setup; reusing the helper keeps it consistent.

♻️ Suggested refactor
-        past_key_values = HFInferenceParams(
-            max_batch_size=2 * num_beams,
-            max_sequence_length=256,
-            num_heads_kv=config.num_key_value_heads,
-            head_dim_k=config.hidden_size // config.num_attention_heads,
-            dtype=torch.bfloat16,
-            qkv_format="thd",
-            max_ctx_len=256,
-        )
-        for layer_number in range(1, config.num_hidden_layers + 1):
-            past_key_values.allocate_memory(layer_number)
+        past_key_values = self._create_inference_params(
+            config,
+            batch_size=2,
+            max_seq_len=256,
+            num_beams=num_beams,
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py` around lines
241 - 290, The test test_te_mixtral_model_generate_with_cache_beam_search
duplicates KV-cache setup; replace the manual HFInferenceParams construction and
loop with a call to the existing helper _create_inference_params (or whatever
public helper is present) to build and allocate past_key_values for the model
config, then use that returned past_key_values in the generate() call; ensure
you pass the same args (dtype, qkv_format, max_ctx_len, max_batch_size, etc.)
into _create_inference_params so behavior remains identical and remove the
manual for-loop that calls past_key_values.allocate_memory.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@bionemo-recipes/models/esm2/tests/common/__init__.py`:
- Line 21: The module docstring currently has the two bullet points for
BaseModelTest and TestTolerances collapsed onto a single line; update the
top-level docstring in this tests/common __init__.py so each item is its own
bullet on its own line (e.g., "- BaseModelTest: Base test class..." newline "-
TestTolerances: Dataclass for model-specific numerical tolerances"), and apply
the same fix to the identical docstrings in the llama3 and mixtral __init__.py
files to ensure consistent rendering.

In `@bionemo-recipes/models/mixtral/collator.py`:
- Around line 811-868: The function _process_tensor_bshd currently uses floor
division for chunk_size which silently drops tail tokens; add an explicit
divisibility guard after computing total_chunks: if seq_len % total_chunks != 0
raise a ValueError (with a clear message referencing seq_len and total_chunks)
so we fail fast rather than truncating data; keep the existing check for
chunk_size==0 but add this new check before computing chunk indices and slicing.

In `@bionemo-recipes/models/mixtral/README.md`:
- Around line 46-48: The README example disables KV-cache by passing
use_cache=False to model_te.generate, contradicting the "KV-cache inference" ✅
claim; update the quick-start snippet to exercise KV-cache by removing the
use_cache override or setting use_cache=True when calling model_te.generate
(inside the with torch.no_grad() block) so the example actually uses the model's
KV-cache inference path and matches the Feature Support table.
- Around line 81-85: The "Validating Converted Models" section is circular and
lacks runnable comparison commands; update it in the README by either (a)
including a concrete, copy-pastable snippet under the "Inference Examples" /
"Validating Converted Models" headings that shows how to load the baseline
Hugging Face model and the converted model and compute/compare logits and loss
(e.g., commands or Python call sequence to run inference for both models and
diff their outputs), or (b) add a clear link and brief instruction pointing
readers to the golden-value test test_modeling_mixtral.py explaining exactly
which test function/assertion to run and how to interpret its outputs; reference
the "Inference Examples" section and the test file name test_modeling_mixtral.py
so readers can locate the code to run.
- Around line 25-50: The README code examples use bare imports like "from
convert import convert_mixtral_hf_to_te" and "from modeling_mixtral_te import
..." which will raise ModuleNotFoundError unless the current working directory
is bionemo-recipes/models/mixtral; update the snippets to include a one-line
preamble instructing users to either run the snippet from that directory (e.g.,
"cd bionemo-recipes/models/mixtral") or to set up the PYTHONPATH/sys.path or
install the package via the documented workflow (pip install -r
requirements.txt) before running, and add that same short note to every block
that uses convert or modeling_mixtral_te so users won't hit silent import
failures.

In `@bionemo-recipes/recipes/esm2_native_te/collator.py`:
- Around line 811-868: The function _process_tensor_bshd currently can drop tail
tokens when seq_len is not divisible by (2 * cp_world_size); add an explicit
divisibility check after computing total_chunks (or chunk_size) and if seq_len %
total_chunks != 0 raise a ValueError with a clear message (e.g. "Sequence length
{seq_len} must be divisible by {total_chunks} (2 * cp_world_size) for BSHD
context parallelism") so the function fails fast instead of silently dropping
tokens.

In `@bionemo-recipes/recipes/esm2_native_te/tests/test_scheduler.py`:
- Line 21: The test imports the recipe-local scheduler module using an absolute
import; replace the top-level import in test_scheduler.py so it explicitly
references the local scheduler module by using a relative import (e.g. import
get_linear_schedule_with_warmup from ..scheduler) so that
get_linear_schedule_with_warmup is resolved from the recipe's scheduler module
rather than relying on conftest.py sys.path manipulation.


ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6a60786 and 773d581.

📒 Files selected for processing (35)
  • bionemo-recipes/models/amplify/src/amplify/state.py
  • bionemo-recipes/models/esm2/README.md
  • bionemo-recipes/models/esm2/collator.py
  • bionemo-recipes/models/esm2/convert.py
  • bionemo-recipes/models/esm2/modeling_esm_te.py
  • bionemo-recipes/models/esm2/state.py
  • bionemo-recipes/models/esm2/tests/common/__init__.py
  • bionemo-recipes/models/llama3/collator.py
  • bionemo-recipes/models/llama3/convert.py
  • bionemo-recipes/models/llama3/modeling_llama_te.py
  • bionemo-recipes/models/llama3/state.py
  • bionemo-recipes/models/llama3/tests/common/__init__.py
  • bionemo-recipes/models/mixtral/README.md
  • bionemo-recipes/models/mixtral/collator.py
  • bionemo-recipes/models/mixtral/convert.py
  • bionemo-recipes/models/mixtral/modeling_mixtral_te.py
  • bionemo-recipes/models/mixtral/state.py
  • bionemo-recipes/models/mixtral/tests/common/__init__.py
  • bionemo-recipes/models/mixtral/tests/test_export.py
  • bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py
  • bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py
  • bionemo-recipes/recipes/esm2_native_te/README.md
  • bionemo-recipes/recipes/esm2_native_te/checkpoint.py
  • bionemo-recipes/recipes/esm2_native_te/collator.py
  • bionemo-recipes/recipes/esm2_native_te/dataset.py
  • bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py
  • bionemo-recipes/recipes/esm2_native_te/tests/test_checkpoint_pruning.py
  • bionemo-recipes/recipes/esm2_native_te/tests/test_scheduler.py
  • bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
  • bionemo-recipes/recipes/esm2_peft_te/collator.py
  • bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py
  • bionemo-recipes/recipes/llama3_native_te/checkpoint.py
  • bionemo-recipes/recipes/llama3_native_te/collator.py
  • bionemo-recipes/recipes/llama3_native_te/dataset.py
  • bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Collaborator

@trvachov trvachov left a comment


One nit and one real comment about future security alerts.

@pstjohn
Collaborator Author

pstjohn commented Feb 25, 2026

@coderabbitai resolve

@jomitchellnv
Copy link
Collaborator

jomitchellnv commented Feb 26, 2026

Release notes make it clear this MR covers four buckets:
1. Mixtral + TE support / model features (context-parallel helpers, token dropout)
2. Refactors (state transform API, checkpoint/dataset utils, collators)
3. Docs (state-dict system + input format docs + README/config)
4. Tests (Mixtral export validation, checkpoint/scheduler tests, base class updates)

Would you be open to splitting this into 2–4 MRs along those lines? As-is it’s difficult to review in one pass, and splitting would align better with our CONTRIBUTING guidance (“Make sure your PR does one thing/Have a clear answer to ‘What does this PR do?’")

@jomitchellnv jomitchellnv self-requested a review February 26, 2026 23:53
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/claude-review branch from 0748173 to e22b2c5 on February 27, 2026 at 14:54
@NVIDIA NVIDIA deleted a comment from jomitchellnv Feb 27, 2026
@NVIDIA NVIDIA deleted a comment from coderabbitai bot Feb 27, 2026
@pstjohn pstjohn enabled auto-merge February 27, 2026 14:59
@pstjohn pstjohn added this pull request to the merge queue Feb 27, 2026
Merged via the queue into NVIDIA:main with commit 6f259a8 Feb 27, 2026
22 checks passed
@pstjohn pstjohn deleted the pstjohn/claude-review branch February 27, 2026 15:37
