4343_CODONFM_TE_DIR = _REPO_ROOT / "recipes" / "codonfm_ptl_te"
4444sys .path .insert (0 , str (_CODONFM_TE_DIR ))
4545
46- from codonfm_sae .data import read_codon_csv
47- from sae .analysis import compute_feature_stats , compute_feature_umap , save_feature_atlas
48- from sae .architectures import TopKSAE
49- from sae .utils import get_device , set_seed
50- from src .data .preprocess .codon_sequence import process_item
51- from src .inference .encodon import EncodonInference
46+ from codonfm_sae .data import read_codon_csv # noqa: E402
47+ from sae .analysis import compute_feature_stats , compute_feature_umap , save_feature_atlas # noqa: E402
48+ from sae .architectures import TopKSAE # noqa: E402
49+ from sae .utils import get_device , set_seed # noqa: E402
50+ from src .data .preprocess .codon_sequence import process_item # noqa: E402
51+ from src .inference .encodon import EncodonInference # noqa: E402
5252
5353
54- def parse_args ():
54+ def parse_args (): # noqa: D103
5555 p = argparse .ArgumentParser (description = "Generate CodonFM SAE dashboard data" )
5656 p .add_argument ("--checkpoint" , type = str , required = True , help = "Path to SAE checkpoint .pt file" )
5757 p .add_argument ("--top-k" , type = int , default = None , help = "Override top-k (default: read from checkpoint)" )
@@ -75,7 +75,7 @@ def parse_args():
7575 return p .parse_args ()
7676
7777
78- def load_sae_from_checkpoint (checkpoint_path : str , top_k_override : int | None = None ) -> TopKSAE :
78+ def load_sae_from_checkpoint (checkpoint_path : str , top_k_override : int | None = None ) -> TopKSAE : # noqa: D103
7979 ckpt = torch .load (checkpoint_path , map_location = "cpu" , weights_only = False )
8080 state_dict = ckpt ["model_state_dict" ]
8181 if any (k .startswith ("module." ) for k in state_dict ):
@@ -158,7 +158,6 @@ def extract_activations_3d(
158158
159159 # Pad to same seq_len across batches
160160 max_len = max (e .shape [1 ] for e in all_embeddings )
161- hidden_dim = all_embeddings [0 ].shape [2 ]
162161
163162 padded_emb = []
164163 padded_masks = []
@@ -182,9 +181,9 @@ def export_codon_features_parquet(
182181 output_dir : Path ,
183182 n_examples : int = 6 ,
184183 device : str = "cuda" ,
185- records : list = None ,
186- variant_delta_map : dict = None ,
187- precomputed_max_acts : torch .Tensor = None ,
184+ records : list | None = None ,
185+ variant_delta_map : dict | None = None ,
186+ precomputed_max_acts : torch .Tensor | None = None ,
188187):
189188 """Export per-codon feature activations for dashboard.
190189
@@ -358,7 +357,7 @@ def compute_variant_analysis(
358357 activations : torch .Tensor ,
359358 masks : torch .Tensor ,
360359 device : str = "cuda" ,
361- score_column : str = None ,
360+ score_column : str | None = None ,
362361) -> dict :
363362 """Compute per-feature variant analysis with multi-score, local deltas, and distribution metrics.
364363
@@ -562,7 +561,7 @@ def compute_variant_analysis(
562561
563562 # ── Trinuc context distribution per feature ──────────────────────
564563 # Only variant rows (ref rows have no trinuc_context)
565- unique_trinucs = sorted (set ( t for t in trinuc_contexts if t ) )
564+ unique_trinucs = sorted ({ t for t in trinuc_contexts if t } )
566565 trinuc_to_idx = {t : i for i , t in enumerate (unique_trinucs )}
567566 n_trinucs = len (unique_trinucs )
568567 print (f" { n_trinucs } unique trinucleotide contexts" )
@@ -595,7 +594,7 @@ def compute_variant_analysis(
595594
596595 # ── Gene distribution per feature ────────────────────────────────
597596 # Uses all sequences (every row has a gene)
598- unique_genes = sorted (set ( g for g in genes if g ) )
597+ unique_genes = sorted ({ g for g in genes if g } )
599598 gene_to_idx = {g : i for i , g in enumerate (unique_genes )}
600599 n_genes_total = len (unique_genes )
601600 print (f" { n_genes_total } unique genes" )
@@ -764,7 +763,7 @@ def compute_variant_analysis(
764763 }
765764
766765
767- def main ():
766+ def main (): # noqa: D103
768767 args = parse_args ()
769768 set_seed (args .seed )
770769 device = args .device or get_device ()
@@ -775,7 +774,6 @@ def main():
775774
776775 # 1. Load SAE
777776 sae = load_sae_from_checkpoint (args .checkpoint , top_k_override = args .top_k )
778- n_features = sae .hidden_dim
779777
780778 # 2. Load Encodon
781779 print (f"\n Loading Encodon from { args .model_path } ..." )
0 commit comments