Skip to content

Commit fcb07a7

Browse files
committed
Adding tests too
1 parent bdc887e commit fcb07a7

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

tests/test_async_client.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
import os
1111
from contextlib import asynccontextmanager
1212
from unittest.mock import AsyncMock, MagicMock, patch
13+
import aiohttp
1314
import pytest
15+
from cohere.base_client import AsyncBaseCohere
1416
from cohere.client import AsyncClient
1517
from cohere.client_v2 import AsyncClientV2
18+
from cohere.core.http_client import AsyncHttpClient
1619
from cohere.errors import BadRequestError, UnauthorizedError, TooManyRequestsError, InternalServerError
1720
from cohere.core.api_error import ApiError
1821

@@ -536,5 +539,76 @@ async def mock_stream_context(*args, **kwargs):
536539
assert mock_response.json.call_count >= 1
537540

538541

542+
# ---- Regression tests for aio-libs/aiohttp#1142: force_close / allow_redirects fix ----
543+
544+
545+
@pytest.mark.asyncio
546+
async def test_connector_never_force_closes_connections():
547+
"""
548+
Regression: TCPConnector must NOT receive force_close=True regardless of the
549+
follow_redirects setting. The previous implementation on AsyncBaseCohere used
550+
`force_close=not follow_redirects`, which disabled keep-alive connection pooling
551+
and caused TIME_WAIT port exhaustion when making thousands of concurrent requests.
552+
Redirect behaviour is now handled per-request via allow_redirects instead.
553+
"""
554+
for follow_redirects_val in (True, False):
555+
with patch("aiohttp.TCPConnector") as mock_connector_cls, patch("aiohttp.ClientSession"):
556+
mock_connector_cls.return_value = MagicMock()
557+
# AsyncBaseCohere is where follow_redirects is consumed and the connector is built
558+
AsyncBaseCohere(token="test-key", follow_redirects=follow_redirects_val)
559+
560+
call_kwargs = mock_connector_cls.call_args.kwargs if mock_connector_cls.call_args else {}
561+
assert call_kwargs.get("force_close", False) is not True, (
562+
f"force_close must not be True when follow_redirects={follow_redirects_val!r}; "
563+
"this kills TCP keep-alive and exhausts ephemeral ports under high concurrency"
564+
)
565+
566+
567+
@pytest.mark.asyncio
568+
async def test_allow_redirects_forwarded_per_request():
569+
"""
570+
follow_redirects on the client is forwarded as allow_redirects= on each
571+
aiohttp session.request() call. This is the correct per-request mechanism;
572+
the connector-level force_close must NOT be used for this purpose.
573+
"""
574+
mock_response = AsyncMock()
575+
mock_response.status = 200
576+
mock_response.status_code = 200
577+
mock_response.headers = {}
578+
mock_response.read = AsyncMock(return_value=b"{}")
579+
mock_response.content_type = "application/json"
580+
mock_response.text = "{}"
581+
582+
for follow_redirects_val in (True, False):
583+
captured_request_kwargs: dict = {}
584+
585+
async def capture_request(*args, **kwargs):
586+
captured_request_kwargs.update(kwargs)
587+
return mock_response
588+
589+
mock_session = MagicMock()
590+
mock_session.request = capture_request
591+
592+
http_client = AsyncHttpClient(
593+
aiohttp_session=mock_session,
594+
base_timeout=lambda: 30.0,
595+
base_headers=lambda: {"Authorization": "Bearer test"},
596+
base_url=lambda: "https://api.cohere.com",
597+
follow_redirects=follow_redirects_val,
598+
)
599+
try:
600+
await http_client.request(method="GET", path="/v1/chat")
601+
except Exception:
602+
pass # response parsing may fail; we only care about the kwargs forwarded
603+
604+
assert "allow_redirects" in captured_request_kwargs, (
605+
"allow_redirects must be passed to session.request() — "
606+
"redirect handling belongs at the request level, not the connector level"
607+
)
608+
assert captured_request_kwargs["allow_redirects"] == follow_redirects_val, (
609+
f"allow_redirects should equal follow_redirects={follow_redirects_val!r}"
610+
)
611+
612+
539613
if __name__ == "__main__":
540614
asyncio.run(main())

0 commit comments

Comments
 (0)