Skip to content

Commit d0847df

Browse files
committed
lint dashboard
Signed-off-by: jwilber <jwilber@nvidia.com>
1 parent ef4fa1a commit d0847df

File tree

1 file changed

+15
-17
lines changed
  • bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts

1 file changed

+15
-17
lines changed

bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/dashboard.py

Lines changed: 15 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -43,15 +43,15 @@
4343
_CODONFM_TE_DIR = _REPO_ROOT / "recipes" / "codonfm_ptl_te"
4444
sys.path.insert(0, str(_CODONFM_TE_DIR))
4545

46-
from codonfm_sae.data import read_codon_csv
47-
from sae.analysis import compute_feature_stats, compute_feature_umap, save_feature_atlas
48-
from sae.architectures import TopKSAE
49-
from sae.utils import get_device, set_seed
50-
from src.data.preprocess.codon_sequence import process_item
51-
from src.inference.encodon import EncodonInference
46+
from codonfm_sae.data import read_codon_csv # noqa: E402
47+
from sae.analysis import compute_feature_stats, compute_feature_umap, save_feature_atlas # noqa: E402
48+
from sae.architectures import TopKSAE # noqa: E402
49+
from sae.utils import get_device, set_seed # noqa: E402
50+
from src.data.preprocess.codon_sequence import process_item # noqa: E402
51+
from src.inference.encodon import EncodonInference # noqa: E402
5252

5353

54-
def parse_args():
54+
def parse_args(): # noqa: D103
5555
p = argparse.ArgumentParser(description="Generate CodonFM SAE dashboard data")
5656
p.add_argument("--checkpoint", type=str, required=True, help="Path to SAE checkpoint .pt file")
5757
p.add_argument("--top-k", type=int, default=None, help="Override top-k (default: read from checkpoint)")
@@ -75,7 +75,7 @@ def parse_args():
7575
return p.parse_args()
7676

7777

78-
def load_sae_from_checkpoint(checkpoint_path: str, top_k_override: int | None = None) -> TopKSAE:
78+
def load_sae_from_checkpoint(checkpoint_path: str, top_k_override: int | None = None) -> TopKSAE: # noqa: D103
7979
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
8080
state_dict = ckpt["model_state_dict"]
8181
if any(k.startswith("module.") for k in state_dict):
@@ -158,7 +158,6 @@ def extract_activations_3d(
158158

159159
# Pad to same seq_len across batches
160160
max_len = max(e.shape[1] for e in all_embeddings)
161-
hidden_dim = all_embeddings[0].shape[2]
162161

163162
padded_emb = []
164163
padded_masks = []
@@ -182,9 +181,9 @@ def export_codon_features_parquet(
182181
output_dir: Path,
183182
n_examples: int = 6,
184183
device: str = "cuda",
185-
records: list = None,
186-
variant_delta_map: dict = None,
187-
precomputed_max_acts: torch.Tensor = None,
184+
records: list | None = None,
185+
variant_delta_map: dict | None = None,
186+
precomputed_max_acts: torch.Tensor | None = None,
188187
):
189188
"""Export per-codon feature activations for dashboard.
190189
@@ -358,7 +357,7 @@ def compute_variant_analysis(
358357
activations: torch.Tensor,
359358
masks: torch.Tensor,
360359
device: str = "cuda",
361-
score_column: str = None,
360+
score_column: str | None = None,
362361
) -> dict:
363362
"""Compute per-feature variant analysis with multi-score, local deltas, and distribution metrics.
364363
@@ -562,7 +561,7 @@ def compute_variant_analysis(
562561

563562
# ── Trinuc context distribution per feature ──────────────────────
564563
# Only variant rows (ref rows have no trinuc_context)
565-
unique_trinucs = sorted(set(t for t in trinuc_contexts if t))
564+
unique_trinucs = sorted({t for t in trinuc_contexts if t})
566565
trinuc_to_idx = {t: i for i, t in enumerate(unique_trinucs)}
567566
n_trinucs = len(unique_trinucs)
568567
print(f" {n_trinucs} unique trinucleotide contexts")
@@ -595,7 +594,7 @@ def compute_variant_analysis(
595594

596595
# ── Gene distribution per feature ────────────────────────────────
597596
# Uses all sequences (every row has a gene)
598-
unique_genes = sorted(set(g for g in genes if g))
597+
unique_genes = sorted({g for g in genes if g})
599598
gene_to_idx = {g: i for i, g in enumerate(unique_genes)}
600599
n_genes_total = len(unique_genes)
601600
print(f" {n_genes_total} unique genes")
@@ -764,7 +763,7 @@ def compute_variant_analysis(
764763
}
765764

766765

767-
def main():
766+
def main(): # noqa: D103
768767
args = parse_args()
769768
set_seed(args.seed)
770769
device = args.device or get_device()
@@ -775,7 +774,6 @@ def main():
775774

776775
# 1. Load SAE
777776
sae = load_sae_from_checkpoint(args.checkpoint, top_k_override=args.top_k)
778-
n_features = sae.hidden_dim
779777

780778
# 2. Load Encodon
781779
print(f"\nLoading Encodon from {args.model_path}...")

0 commit comments

Comments
 (0)