Skip to content

Commit c927b7e

Browse files
test: fix async tests wrapped by patch decorators
1 parent e62de08 commit c927b7e

File tree

1 file changed

+124
-122
lines changed

1 file changed

+124
-122
lines changed

tests/guardrails/test_model_engine.py

Lines changed: 124 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -195,201 +195,203 @@ def test_client_initially_none(self):
195195
class TestModelEngineLifecycle:
196196
"""Test the ModelEngine start() and stop() client lifecycle."""
197197

198-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "key"})
199198
@pytest.mark.asyncio
200199
async def test_start_stop_lifecycle(self):
201200
"""start() creates the client, stop() tears it down to None."""
202-
engine = ModelEngine(_make_model())
203-
assert engine._client is None
204-
assert engine._running is False
205-
await engine.start()
206-
assert engine._client is not None
207-
assert engine._running is True
208-
await engine.stop()
209-
assert engine._client is None
210-
assert engine._running is False
201+
# NOTE: use a context manager instead of the decorator, so pytest can
202+
# correctly detect and run the coroutine test across Python versions.
203+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "key"}):
204+
engine = ModelEngine(_make_model())
205+
assert engine._client is None
206+
assert engine._running is False
207+
await engine.start()
208+
assert engine._client is not None
209+
assert engine._running is True
210+
await engine.stop()
211+
assert engine._client is None
212+
assert engine._running is False
211213

212-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "key"})
213214
@pytest.mark.asyncio
214215
async def test_start_is_idempotent(self):
215216
"""Calling start() twice reuses the same client instance."""
216-
engine = ModelEngine(_make_model())
217-
await engine.start()
218-
first_client = engine._client
219-
await engine.start() # should not create a new client
220-
assert engine._client is first_client
221-
await engine.stop()
217+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "key"}):
218+
engine = ModelEngine(_make_model())
219+
await engine.start()
220+
first_client = engine._client
221+
await engine.start() # should not create a new client
222+
assert engine._client is first_client
223+
await engine.stop()
222224

223-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "key"})
224225
@pytest.mark.asyncio
225226
async def test_stop_when_no_client_is_noop(self):
226227
"""stop() without a prior start() does not raise."""
227-
engine = ModelEngine(_make_model())
228-
await engine.stop() # should not raise
229-
assert engine._running is False
228+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "key"}):
229+
engine = ModelEngine(_make_model())
230+
await engine.stop() # should not raise
231+
assert engine._running is False
230232

231-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "key"})
232233
@pytest.mark.asyncio
233234
async def test_stop_is_idempotent(self):
234235
"""Calling stop() twice does not raise."""
235-
engine = ModelEngine(_make_model())
236-
await engine.start()
237-
await engine.stop()
238-
await engine.stop() # second stop is a no-op
239-
assert engine._running is False
236+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "key"}):
237+
engine = ModelEngine(_make_model())
238+
await engine.start()
239+
await engine.stop()
240+
await engine.stop() # second stop is a no-op
241+
assert engine._running is False
240242

241243

242244
class TestModelEngineContextManager:
243245
"""Test async context manager calls start/stop correctly."""
244246

245-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "key"})
246247
@pytest.mark.asyncio
247248
async def test_context_manager_calls_start_and_stop(self):
248249
"""async with calls start() on enter and stop() on exit."""
249-
engine = ModelEngine(_make_model())
250-
assert engine._running is False
251-
async with engine as eng:
252-
assert eng is engine
253-
assert engine._running is True
254-
assert engine._client is not None
255-
assert engine._running is False
256-
assert engine._client is None
250+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "key"}):
251+
engine = ModelEngine(_make_model())
252+
assert engine._running is False
253+
async with engine as eng:
254+
assert eng is engine
255+
assert engine._running is True
256+
assert engine._client is not None
257+
assert engine._running is False
258+
assert engine._client is None
257259

258260

259261
class TestModelEngineCall:
260262
"""Test ModelEngine.call() HTTP request construction and error handling."""
261263

262-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"})
263264
@pytest.mark.asyncio
264265
async def test_successful_call(self):
265266
"""Successful call returns parsed JSON and posts to correct URL with headers."""
266-
model = _make_model()
267-
engine = ModelEngine(model)
267+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"}):
268+
model = _make_model()
269+
engine = ModelEngine(model)
268270

269-
expected_response = {"choices": [{"message": {"role": "assistant", "content": "Hello!"}}]}
271+
expected_response = {"choices": [{"message": {"role": "assistant", "content": "Hello!"}}]}
270272

271-
mock_response = AsyncMock()
272-
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
273-
mock_response.status = 200
274-
mock_response.json = AsyncMock(return_value=expected_response)
273+
mock_response = AsyncMock()
274+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
275+
mock_response.status = 200
276+
mock_response.json = AsyncMock(return_value=expected_response)
275277

276-
mock_client = AsyncMock()
277-
mock_client.post = MagicMock(return_value=mock_response)
278-
mock_client.closed = False
278+
mock_client = AsyncMock()
279+
mock_client.post = MagicMock(return_value=mock_response)
280+
mock_client.closed = False
279281

280-
engine._client = mock_client
281-
engine._running = True
282+
engine._client = mock_client
283+
engine._running = True
282284

283-
messages = [{"role": "user", "content": "Hi"}]
284-
result = await engine.call(messages)
285-
assert result == expected_response
285+
messages = [{"role": "user", "content": "Hi"}]
286+
result = await engine.call(messages)
287+
assert result == expected_response
286288

287-
# Verify correct URL
288-
call_args = mock_client.post.call_args
289-
assert _CHAT_COMPLETIONS_ENDPOINT in call_args[0][0]
289+
# Verify correct URL
290+
call_args = mock_client.post.call_args
291+
assert _CHAT_COMPLETIONS_ENDPOINT in call_args[0][0]
290292

291-
expected_url = _ENGINE_BASE_URLS[model.engine] + "/v1/chat/completions"
292-
expected_json = {"messages": messages, "model": model.model}
293-
expected_headers = {"Content-Type": "application/json", "Authorization": "Bearer test-key"}
294-
mock_client.post.assert_called_once_with(expected_url, json=expected_json, headers=expected_headers)
293+
expected_url = _ENGINE_BASE_URLS[model.engine] + "/v1/chat/completions"
294+
expected_json = {"messages": messages, "model": model.model}
295+
expected_headers = {"Content-Type": "application/json", "Authorization": "Bearer test-key"}
296+
mock_client.post.assert_called_once_with(expected_url, json=expected_json, headers=expected_headers)
295297

296-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"})
297298
@pytest.mark.asyncio
298299
async def test_call_includes_model_name_and_messages_in_body(self):
299300
"""Request body contains model name, messages, and extra kwargs."""
300-
engine = ModelEngine(_make_model(model="my-llm"))
301+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"}):
302+
engine = ModelEngine(_make_model(model="my-llm"))
301303

302-
mock_response = AsyncMock()
303-
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
304-
mock_response.status = 200
305-
mock_response.json = AsyncMock(return_value={"choices": [{"message": {"content": "ok"}}]})
304+
mock_response = AsyncMock()
305+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
306+
mock_response.status = 200
307+
mock_response.json = AsyncMock(return_value={"choices": [{"message": {"content": "ok"}}]})
306308

307-
mock_client = AsyncMock()
308-
mock_client.post = MagicMock(return_value=mock_response)
309-
mock_client.closed = False
310-
engine._client = mock_client
311-
engine._running = True
309+
mock_client = AsyncMock()
310+
mock_client.post = MagicMock(return_value=mock_response)
311+
mock_client.closed = False
312+
engine._client = mock_client
313+
engine._running = True
312314

313-
messages = [{"role": "user", "content": "Hello"}]
314-
await engine.call(messages, temperature=0.7)
315+
messages = [{"role": "user", "content": "Hello"}]
316+
await engine.call(messages, temperature=0.7)
315317

316-
call_kwargs = mock_client.post.call_args
317-
body = call_kwargs[1]["json"]
318-
assert body["model"] == "my-llm"
319-
assert body["messages"] == messages
320-
assert body["temperature"] == 0.7
318+
call_kwargs = mock_client.post.call_args
319+
body = call_kwargs[1]["json"]
320+
assert body["model"] == "my-llm"
321+
assert body["messages"] == messages
322+
assert body["temperature"] == 0.7
321323

322-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"})
323324
@pytest.mark.asyncio
324325
async def test_call_without_api_key_omits_auth_header(self):
325326
"""No Authorization header when api_key is None."""
326-
engine = ModelEngine(_make_model())
327-
engine.api_key = None # simulate no API key
327+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"}):
328+
engine = ModelEngine(_make_model())
329+
engine.api_key = None # simulate no API key
328330

329-
mock_response = AsyncMock()
330-
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
331-
mock_response.status = 200
332-
mock_response.json = AsyncMock(return_value={"choices": [{"message": {"content": "ok"}}]})
331+
mock_response = AsyncMock()
332+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
333+
mock_response.status = 200
334+
mock_response.json = AsyncMock(return_value={"choices": [{"message": {"content": "ok"}}]})
333335

334-
mock_client = AsyncMock()
335-
mock_client.post = MagicMock(return_value=mock_response)
336-
mock_client.closed = False
337-
engine._client = mock_client
338-
engine._running = True
336+
mock_client = AsyncMock()
337+
mock_client.post = MagicMock(return_value=mock_response)
338+
mock_client.closed = False
339+
engine._client = mock_client
340+
engine._running = True
339341

340-
await engine.call([{"role": "user", "content": "Hi"}])
342+
await engine.call([{"role": "user", "content": "Hi"}])
341343

342-
call_kwargs = mock_client.post.call_args
343-
headers = call_kwargs[1]["headers"]
344-
assert "Authorization" not in headers
344+
call_kwargs = mock_client.post.call_args
345+
headers = call_kwargs[1]["headers"]
346+
assert "Authorization" not in headers
345347

346-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"})
347348
@pytest.mark.asyncio
348349
async def test_call_http_error_raises_model_engine_error(self):
349350
"""HTTP 4xx/5xx raises ModelEngineError with status and model name."""
350-
engine = ModelEngine(_make_model())
351+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"}):
352+
engine = ModelEngine(_make_model())
351353

352-
mock_response = AsyncMock()
353-
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
354-
mock_response.status = 400
355-
mock_response.text = AsyncMock(return_value='{"error": "bad request"}')
354+
mock_response = AsyncMock()
355+
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
356+
mock_response.status = 400
357+
mock_response.text = AsyncMock(return_value='{"error": "bad request"}')
356358

357-
mock_client = AsyncMock()
358-
mock_client.post = MagicMock(return_value=mock_response)
359-
mock_client.closed = False
360-
engine._client = mock_client
361-
engine._running = True
359+
mock_client = AsyncMock()
360+
mock_client.post = MagicMock(return_value=mock_response)
361+
mock_client.closed = False
362+
engine._client = mock_client
363+
engine._running = True
362364

363-
with pytest.raises(ModelEngineError) as exc_info:
364-
await engine.call([{"role": "user", "content": "Hi"}])
365+
with pytest.raises(ModelEngineError) as exc_info:
366+
await engine.call([{"role": "user", "content": "Hi"}])
365367

366-
assert exc_info.value.status == 400
367-
assert exc_info.value.model_name == "meta/llama-3.3-70b-instruct"
368+
assert exc_info.value.status == 400
369+
assert exc_info.value.model_name == "meta/llama-3.3-70b-instruct"
368370

369-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"})
370371
@pytest.mark.asyncio
371372
async def test_call_unexpected_exception_wraps_in_model_engine_error(self):
372373
"""Non-HTTP exceptions are wrapped in ModelEngineError."""
373-
engine = ModelEngine(_make_model())
374+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"}):
375+
engine = ModelEngine(_make_model())
374376

375-
mock_client = AsyncMock()
376-
mock_client.post = MagicMock(side_effect=RuntimeError("connection dropped"))
377-
mock_client.closed = False
378-
engine._client = mock_client
379-
engine._running = True
377+
mock_client = AsyncMock()
378+
mock_client.post = MagicMock(side_effect=RuntimeError("connection dropped"))
379+
mock_client.closed = False
380+
engine._client = mock_client
381+
engine._running = True
380382

381-
with pytest.raises(ModelEngineError, match="connection dropped"):
382-
await engine.call([{"role": "user", "content": "Hi"}])
383+
with pytest.raises(ModelEngineError, match="connection dropped"):
384+
await engine.call([{"role": "user", "content": "Hi"}])
383385

384-
@patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"})
385386
@pytest.mark.asyncio
386387
async def test_call_raises_if_not_started(self):
387388
"""call() raises ModelEngineError if start() hasn't been called."""
388-
engine = ModelEngine(_make_model())
389-
assert engine._client is None
389+
with patch.dict("os.environ", {"NVIDIA_API_KEY": "test-key"}):
390+
engine = ModelEngine(_make_model())
391+
assert engine._client is None
390392

391-
with pytest.raises(ModelEngineError, match="has not been started"):
392-
await engine.call([{"role": "user", "content": "Hi"}])
393+
with pytest.raises(ModelEngineError, match="has not been started"):
394+
await engine.call([{"role": "user", "content": "Hi"}])
393395

394396

395397
class TestModelEngineConstants:

0 commit comments

Comments
 (0)