Skip to content

Commit 668a213

Browse files
committed
Add minimal testing for GANOrchestrator and implement simple discriminator fixture
1 parent b2aae07 commit 668a213

File tree

3 files changed

+179
-1
lines changed

3 files changed

+179
-1
lines changed

tests/engine/conftest.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,43 @@ def forward(self, x):
112112
return MultiOutputConv()
113113

114114

115+
@pytest.fixture
def simple_discriminator():
    """
    Minimal discriminator model for GAN testing.

    Consumes a stacked input + target tensor of shape (B, 6, H, W) and
    emits one realness score per sample, shape (B, 1).
    Architecture: conv -> ReLU -> global average pool -> linear.
    """
    class SimpleDiscriminator(nn.Module):
        def __init__(self):
            super().__init__()
            # 6 input channels: the stacked input + target pair.
            self.conv = nn.Conv2d(
                in_channels=6,
                out_channels=16,
                kernel_size=3,
                padding=1,
                bias=True,
            )
            # Global average pooling collapses spatial dims to 1x1.
            self.pool = nn.AdaptiveAvgPool2d(1)
            # Map pooled features to a single score.
            self.fc = nn.Linear(16, 1)

        def forward(self, x):
            features = torch.relu(self.conv(x))
            pooled = self.pool(features)   # (B, 16, 1, 1)
            flat = pooled.flatten(1)       # (B, 16)
            return self.fc(flat)           # (B, 1)

    return SimpleDiscriminator()
144+
145+
146+
@pytest.fixture
def random_stack():
    """Random discriminator input: tensor of shape (batch=2, channels=6, 8, 8)."""
    return torch.randn(2, 6, 8, 8)
150+
151+
115152
@pytest.fixture
116153
def sample_inputs():
117154
"""Create sample inputs for loss computation."""
@@ -208,3 +245,31 @@ def forward_pass_context_eval(forward_group, random_input, random_target, torch_
208245
inputs=random_input.to(torch_device),
209246
targets=random_target.to(torch_device),
210247
)
248+
249+
250+
@pytest.fixture
def disc_optimizer(simple_discriminator):
    """Adam optimizer (lr=1e-3) over the simple discriminator's parameters."""
    from torch.optim import Adam
    return Adam(simple_discriminator.parameters(), lr=1e-3)
255+
256+
257+
@pytest.fixture
def discriminator_forward_group(simple_discriminator, disc_optimizer, torch_device):
    """DiscriminatorForwardGroup wired to the simple discriminator and its Adam optimizer."""
    from virtual_stain_flow.engine.forward_groups import DiscriminatorForwardGroup

    return DiscriminatorForwardGroup(
        discriminator=simple_discriminator,
        optimizer=disc_optimizer,
        device=torch_device,
    )
266+
267+
268+
@pytest.fixture
def gan_orchestrator(forward_group, discriminator_forward_group):
    """GANOrchestrator combining the generator and discriminator forward groups."""
    from virtual_stain_flow.engine.orchestrators import GANOrchestrator

    return GANOrchestrator(
        generator_fg=forward_group,
        discriminator_fg=discriminator_forward_group,
    )

tests/engine/test_forward_group.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from virtual_stain_flow.engine.forward_groups import (
77
AbstractForwardGroup,
8-
GeneratorForwardGroup
8+
GeneratorForwardGroup,
9+
DiscriminatorForwardGroup
910
)
1011
from virtual_stain_flow.engine.names import INPUTS, TARGETS, PREDS, GENERATOR_MODEL
1112

@@ -128,3 +129,52 @@ def test_forward_output_arity_mismatch(self, multi_output_model, random_input, r
128129

129130
with pytest.raises(ValueError, match="Model returned 2 outputs.*output_keys expects 1"):
130131
forward_group(train=False, inputs=random_input, targets=random_target)
132+
133+
134+
class TestDiscriminatorForwardGroup:
    """Tests for DiscriminatorForwardGroup train/eval behavior."""

    @staticmethod
    def _build_group(discriminator, optimizer=None):
        """Construct a CPU-bound forward group around *discriminator*."""
        kwargs = dict(device=torch.device("cpu"), discriminator=discriminator)
        if optimizer is not None:
            kwargs["optimizer"] = optimizer
        return DiscriminatorForwardGroup(**kwargs)

    def test_forward_train_mode(self, simple_discriminator, random_stack):
        """train=True must put the model in train mode and keep grads enabled."""
        group = self._build_group(simple_discriminator)

        ctx = group(train=True, stack=random_stack)

        assert group.model.training is True
        assert ctx["p"].requires_grad is True

    def test_forward_eval_mode(self, simple_discriminator, random_stack):
        """train=False must put the model in eval mode with grads disabled."""
        group = self._build_group(simple_discriminator)

        ctx = group(train=False, stack=random_stack)

        assert group.model.training is False
        assert ctx["p"].requires_grad is False

    def test_optimizer_zero_grad(self, simple_discriminator, disc_optimizer, random_stack):
        """A train=True forward must clear existing gradients via optimizer.zero_grad()."""
        group = self._build_group(simple_discriminator, optimizer=disc_optimizer)

        # Seed gradients on every parameter with a trivial scalar loss.
        dummy_loss = sum(p.sum() for p in group.model.parameters())
        dummy_loss.backward()
        assert any(p.grad is not None for p in group.model.parameters())

        # The forward pass is expected to zero the gradients.
        _ = group(train=True, stack=random_stack)

        # Gradients should be None afterwards (set_to_none=True behavior).
        assert all(p.grad is None for p in group.model.parameters())

tests/engine/test_orchestrator.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Tests for GANOrchestrator."""
2+
3+
import torch
4+
5+
from virtual_stain_flow.engine.names import INPUTS, TARGETS, PREDS
6+
7+
8+
class TestGANOrchestrator:
    """Tests for GANOrchestrator forward passes."""

    def test_discriminator_forward(self, gan_orchestrator, random_input, random_target):
        """_discriminator_forward must expose generator outputs, real/fake stacks, and scores."""
        ctx = gan_orchestrator._discriminator_forward(
            train=False,
            inputs=random_input,
            targets=random_target,
        )

        # Generator outputs and discriminator keys must all be present.
        expected_keys = (
            INPUTS, TARGETS, PREDS,
            "real_stack", "fake_stack", "p_real_as_real", "p_fake_as_real",
        )
        for key in expected_keys:
            assert key in ctx

        # Scores are produced per sample.
        batch = random_input.shape[0]
        assert ctx["p_real_as_real"].shape[0] == batch
        assert ctx["p_fake_as_real"].shape[0] == batch

        # real_stack = inputs ++ targets along the channel dim.
        assert torch.allclose(
            ctx["real_stack"], torch.cat([ctx[INPUTS], ctx[TARGETS]], dim=1)
        )
        # fake_stack = inputs ++ preds along the channel dim.
        assert torch.allclose(
            ctx["fake_stack"], torch.cat([ctx[INPUTS], ctx[PREDS]], dim=1)
        )

    def test_generator_forward(self, gan_orchestrator, random_input, random_target):
        """_generator_forward must expose generator outputs plus the fake-as-real score."""
        ctx = gan_orchestrator._generator_forward(
            train=False,
            inputs=random_input,
            targets=random_target,
        )

        # Generator outputs and the discriminator score for fakes must be present.
        for key in (INPUTS, TARGETS, PREDS, "p_fake_as_real"):
            assert key in ctx

        batch = random_input.shape[0]
        assert ctx[PREDS].shape[0] == batch
        assert ctx["p_fake_as_real"].shape[0] == batch
        assert ctx["p_fake_as_real"].shape[1] == 1  # single realness score per sample

0 commit comments

Comments
 (0)