Skip to content

Commit d541c25

Browse files
committed
lint + precommit
Signed-off-by: jwilber <jwilber@nvidia.com>
1 parent d0847df commit d541c25

File tree

9 files changed

+39
-39
lines changed

9 files changed

+39
-39
lines changed

bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _torchrun_prefix(nproc: int) -> list:
5555
return [sys.executable]
5656

5757

58-
def run_extract(cfg: DictConfig, cache_dir: Path) -> None:
58+
def run_extract(cfg: DictConfig, cache_dir: Path) -> None: # noqa: D103
5959
cmd = [
6060
*_torchrun_prefix(cfg.nproc),
6161
str(SCRIPTS_DIR / "extract.py"),
@@ -82,7 +82,7 @@ def run_extract(cfg: DictConfig, cache_dir: Path) -> None:
8282
_run(cmd, f"STEP 1: Extract activations from {cfg.model_path}")
8383

8484

85-
def run_train(cfg: DictConfig, cache_dir: Path, output_dir: Path) -> None:
85+
def run_train(cfg: DictConfig, cache_dir: Path, output_dir: Path) -> None: # noqa: D103
8686
checkpoint_dir = output_dir / "checkpoints"
8787
t = cfg.train
8888

@@ -143,7 +143,7 @@ def run_train(cfg: DictConfig, cache_dir: Path, output_dir: Path) -> None:
143143
_run(cmd, "STEP 2: Train SAE")
144144

145145

146-
def run_eval(cfg: DictConfig, output_dir: Path) -> None:
146+
def run_eval(cfg: DictConfig, output_dir: Path) -> None: # noqa: D103
147147
checkpoint = output_dir / "checkpoints" / "checkpoint_final.pt"
148148
eval_dir = output_dir / "eval"
149149

@@ -176,7 +176,7 @@ def run_eval(cfg: DictConfig, output_dir: Path) -> None:
176176

177177

178178
@hydra.main(version_base=None, config_path="run_configs", config_name="config")
179-
def main(cfg: DictConfig) -> None:
179+
def main(cfg: DictConfig) -> None: # noqa: D103
180180
os.chdir(hydra.utils.get_original_cwd())
181181

182182
print(OmegaConf.to_yaml(cfg))

bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/analyze.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@
4545
_CODONFM_TE_DIR = _REPO_ROOT / "recipes" / "codonfm_ptl_te"
4646
sys.path.insert(0, str(_CODONFM_TE_DIR))
4747

48-
from codonfm_sae.data import read_codon_csv
49-
from sae.architectures import TopKSAE
50-
from sae.utils import get_device, set_seed
51-
from src.data.preprocess.codon_sequence import process_item
52-
from src.inference.encodon import EncodonInference
48+
from codonfm_sae.data import read_codon_csv # noqa: E402
49+
from sae.architectures import TopKSAE # noqa: E402
50+
from sae.utils import get_device, set_seed # noqa: E402
51+
from src.data.preprocess.codon_sequence import process_item # noqa: E402
52+
from src.inference.encodon import EncodonInference # noqa: E402
5353

5454

5555
# ── Standard codon usage table (human, per 1000 codons) ──────────────
@@ -189,7 +189,7 @@
189189
}
190190

191191

192-
def parse_args():
192+
def parse_args(): # noqa: D103
193193
p = argparse.ArgumentParser(description="Analyze CodonFM SAE features")
194194
p.add_argument("--checkpoint", type=str, required=True)
195195
p.add_argument("--top-k", type=int, default=None, help="Override top-k (default: read from checkpoint)")
@@ -232,7 +232,7 @@ def parse_args():
232232
return p.parse_args()
233233

234234

235-
def load_sae(checkpoint_path: str, top_k_override: int | None = None) -> TopKSAE:
235+
def load_sae(checkpoint_path: str, top_k_override: int | None = None) -> TopKSAE: # noqa: D103
236236
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
237237
state_dict = ckpt["model_state_dict"]
238238
if any(k.startswith("module.") for k in state_dict):
@@ -533,7 +533,7 @@ def stream_annotations_and_topk(
533533
# ── 3. Auto-interpretation ───────────────────────────────────────────
534534

535535

536-
def get_llm_client(provider: str, model: str = None):
536+
def get_llm_client(provider: str, model: str | None = None):
537537
"""Create LLM client based on provider."""
538538
from sae.autointerp import (
539539
AnthropicClient,
@@ -834,7 +834,7 @@ def build_feature_labels(
834834
# ── Main ─────────────────────────────────────────────────────────────
835835

836836

837-
def main():
837+
def main(): # noqa: D103
838838
args = parse_args()
839839
set_seed(args.seed)
840840
device = args.device or get_device()

bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/eval.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@
4141
_CODONFM_TE_DIR = _REPO_ROOT / "recipes" / "codonfm_ptl_te"
4242
sys.path.insert(0, str(_CODONFM_TE_DIR))
4343

44-
from codonfm_sae.data import read_codon_csv
45-
from codonfm_sae.eval import evaluate_codonfm_loss_recovered
46-
from sae.architectures import TopKSAE
47-
from sae.utils import get_device, set_seed
48-
from src.inference.encodon import EncodonInference
44+
from codonfm_sae.data import read_codon_csv # noqa: E402
45+
from codonfm_sae.eval import evaluate_codonfm_loss_recovered # noqa: E402
46+
from sae.architectures import TopKSAE # noqa: E402
47+
from sae.utils import get_device, set_seed # noqa: E402
48+
from src.inference.encodon import EncodonInference # noqa: E402
4949

5050

51-
def parse_args():
51+
def parse_args(): # noqa: D103
5252
p = argparse.ArgumentParser(description="Evaluate CodonFM SAE")
5353

5454
# Checkpoint
@@ -110,7 +110,7 @@ def load_sae_from_checkpoint(checkpoint_path: str, top_k_override: int | None =
110110
return sae
111111

112112

113-
def main():
113+
def main(): # noqa: D103
114114
args = parse_args()
115115
set_seed(args.seed)
116116
device = args.device or get_device()

bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/eval_swissprot_f1.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@
5151
_CODONFM_TE_DIR = _REPO_ROOT / "recipes" / "codonfm_ptl_te"
5252
sys.path.insert(0, str(_CODONFM_TE_DIR))
5353

54-
from sae.architectures import TopKSAE
55-
from sae.utils import get_device, set_seed
56-
from src.data.preprocess.codon_sequence import process_item
57-
from src.inference.encodon import EncodonInference
54+
from sae.architectures import TopKSAE # noqa: E402
55+
from sae.utils import get_device, set_seed # noqa: E402
56+
from src.data.preprocess.codon_sequence import process_item # noqa: E402
57+
from src.inference.encodon import EncodonInference # noqa: E402
5858

5959

6060
# ── Annotation parsing (adapted from esm2_sae) ─────────────────────────

bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/extract.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@
4848
_CODONFM_TE_DIR = _REPO_ROOT / "recipes" / "codonfm_ptl_te"
4949
sys.path.insert(0, str(_CODONFM_TE_DIR))
5050

51-
from codonfm_sae.data import read_codon_csv
52-
from sae.activation_store import ActivationStore, ActivationStoreConfig
53-
from src.data.preprocess.codon_sequence import process_item
54-
from src.inference.encodon import EncodonInference
51+
from codonfm_sae.data import read_codon_csv # noqa: E402
52+
from sae.activation_store import ActivationStore, ActivationStoreConfig # noqa: E402
53+
from src.data.preprocess.codon_sequence import process_item # noqa: E402
54+
from src.inference.encodon import EncodonInference # noqa: E402
5555

5656

57-
def parse_args():
57+
def parse_args(): # noqa: D103
5858
p = argparse.ArgumentParser(description="Extract CodonFM layer activations")
5959
p.add_argument(
6060
"--csv-path", type=str, required=True, help="Path to CSV with DNA sequences (auto-detects 'seq'/'cds' column)"
@@ -138,7 +138,7 @@ def _merge_rank_stores(cache_path: Path, world_size: int, metadata: dict) -> Non
138138
print(f"Merged {world_size} rank stores: {total_samples:,} tokens, {shard_idx} shards")
139139

140140

141-
def main():
141+
def main(): # noqa: D103
142142
args = parse_args()
143143
torch.manual_seed(args.seed)
144144

bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/launch_dashboard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _filter_and_copy_parquet(src: Path, dst: Path, live_ids: set):
6060
return n_before, len(df)
6161

6262

63-
def main():
63+
def main(): # noqa: D103
6464
p = argparse.ArgumentParser(description="Launch codon SAE dashboard")
6565
p.add_argument(
6666
"--data-dir",

bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/train.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from sae.utils import get_device, set_seed
4646

4747

48-
def parse_args():
48+
def parse_args(): # noqa: D103
4949
p = argparse.ArgumentParser(
5050
description="Train SAE from cached CodonFM activations",
5151
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@@ -110,7 +110,7 @@ def parse_args():
110110
return p.parse_args()
111111

112112

113-
def build_sae(args, input_dim: int) -> torch.nn.Module:
113+
def build_sae(args, input_dim: int) -> torch.nn.Module: # noqa: D103
114114
hidden_dim = input_dim * args.expansion_factor
115115

116116
if args.model_type == "topk":
@@ -133,7 +133,7 @@ def build_sae(args, input_dim: int) -> torch.nn.Module:
133133
raise ValueError(f"Unknown model type: {args.model_type}")
134134

135135

136-
def build_training_config(args, device: str) -> TrainingConfig:
136+
def build_training_config(args, device: str) -> TrainingConfig: # noqa: D103
137137
return TrainingConfig(
138138
lr=args.lr,
139139
n_epochs=args.n_epochs,
@@ -150,7 +150,7 @@ def build_training_config(args, device: str) -> TrainingConfig:
150150
)
151151

152152

153-
def build_wandb_config(args) -> WandbConfig:
153+
def build_wandb_config(args) -> WandbConfig: # noqa: D103
154154
return WandbConfig(
155155
enabled=args.wandb_enabled,
156156
project=args.wandb_project,
@@ -161,11 +161,11 @@ def build_wandb_config(args) -> WandbConfig:
161161
)
162162

163163

164-
def build_parallel_config(args) -> ParallelConfig:
164+
def build_parallel_config(args) -> ParallelConfig: # noqa: D103
165165
return ParallelConfig(dp_size=args.dp_size)
166166

167167

168-
def main():
168+
def main(): # noqa: D103
169169
args = parse_args()
170170

171171
set_seed(args.seed)

bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/src/codonfm_sae/data/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,5 @@ class CodonRecord:
2626
metadata: Dict[str, Any] = field(default_factory=dict)
2727

2828
@property
29-
def num_codons(self) -> int:
29+
def num_codons(self) -> int: # noqa: D102
3030
return len(self.sequence) // 3

bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/scripts/launch_dashboard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _filter_and_copy_parquet(src: Path, dst: Path, live_ids: set):
5959
return n_before, len(df)
6060

6161

62-
def main():
62+
def main(): # noqa: D103
6363
p = argparse.ArgumentParser(description="Launch ESM2 SAE dashboard")
6464
p.add_argument(
6565
"--data-dir",

0 commit comments

Comments
 (0)