Skip to content

Commit 8b8d05c

Browse files
authored
Add linter tests (#7)
1 parent 0f576ef commit 8b8d05c

File tree

6 files changed

+42
-13
lines changed

6 files changed

+42
-13
lines changed

.github/workflows/pytest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ jobs:
3030
run: |
3131
python -m pip install --upgrade pip
3232
pip install .
33+
pip install black
3334
3435
- name: Run tests
3536
run: |

alignit/visualize.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,7 @@ def get_data(index):
2525
outputs=[gr.Image(type="pil", label="Image"), gr.Text(label="Label")],
2626
title="Dataset Image Viewer",
2727
live=True,
28-
).launch(
29-
share=cfg.share,
30-
server_name=cfg.server_name,
31-
server_port=cfg.server_port
32-
)
28+
).launch(share=cfg.share, server_name=cfg.server_name, server_port=cfg.server_port)
3329

3430

3531
if __name__ == "__main__":

tests/test_alignnet.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44

55

66
def test_alignnet_forward_shapes_cpu():
7-
model = AlignNet(backbone_name="resnet18", backbone_weights=None, use_vector_input=False, output_dim=7)
7+
model = AlignNet(
8+
backbone_name="resnet18",
9+
backbone_weights=None,
10+
use_vector_input=False,
11+
output_dim=7,
12+
)
813
model.eval()
914
x = torch.randn(2, 3, 3, 64, 64) # B=2, N=3 views
1015
with torch.no_grad():
@@ -13,7 +18,12 @@ def test_alignnet_forward_shapes_cpu():
1318

1419

1520
def test_alignnet_with_vector_input():
16-
model = AlignNet(backbone_name="resnet18", backbone_weights=None, use_vector_input=True, output_dim=7)
21+
model = AlignNet(
22+
backbone_name="resnet18",
23+
backbone_weights=None,
24+
use_vector_input=True,
25+
output_dim=7,
26+
)
1727
model.eval()
1828
x = torch.randn(1, 2, 3, 64, 64)
1929
vecs = [torch.randn(5)]
@@ -23,12 +33,18 @@ def test_alignnet_with_vector_input():
2333

2434

2535
def test_alignnet_performance():
26-
model = AlignNet(backbone_name="efficientnet_b0", backbone_weights=None, use_vector_input=True, output_dim=7)
36+
model = AlignNet(
37+
backbone_name="efficientnet_b0",
38+
backbone_weights=None,
39+
use_vector_input=True,
40+
output_dim=7,
41+
)
2742
model.eval()
2843
x = torch.randn(1, 3, 3, 224, 224) # B=1, N=3 views
2944
vecs = [torch.randn(5)]
3045
with torch.no_grad():
3146
import time
47+
3248
start_time = time.time()
3349
for _ in range(10):
3450
y = model(x, vecs)

tests/test_dataset_loader.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
import os
2-
3-
import pytest
4-
51
from alignit.utils.dataset import load_dataset
62

73

@@ -19,6 +15,7 @@ def fake_load_from_disk(p):
1915

2016
# Patch inside module
2117
import alignit.utils.dataset as ds
18+
2219
ds.load_from_disk = fake_load_from_disk
2320

2421
d = load_dataset(str(tmp_path))
@@ -35,6 +32,7 @@ def fake_hf_load_dataset(name):
3532
return Dummy(name)
3633

3734
import alignit.utils.dataset as ds
35+
3836
ds.hf_load_dataset = fake_hf_load_dataset
3937

4038
d = load_dataset("my-dataset/name")

tests/test_lint.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import subprocess
2+
import sys
3+
from pathlib import Path
4+
5+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
6+
7+
8+
def run(cmd: list[str]):
9+
result = subprocess.run(cmd, cwd=PROJECT_ROOT, capture_output=True, text=True)
10+
if result.returncode != 0:
11+
print(result.stdout)
12+
print(result.stderr, file=sys.stderr)
13+
return result.returncode == 0
14+
15+
16+
def test_black():
17+
result = run([sys.executable, "-m", "black", "--check", "--diff", "--verbose", "."])
18+
assert result, "Black formatting issues found"

tests/test_tfs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,5 @@ def test_get_pose_str_format():
2121
pose = np.eye(4)
2222
s = get_pose_str(pose)
2323
assert isinstance(s, str)
24-
parts = s.split(',')
24+
parts = s.split(",")
2525
assert len(parts) == 6

0 commit comments

Comments
 (0)