|
10 | 10 | import os |
11 | 11 | from contextlib import asynccontextmanager |
12 | 12 | from unittest.mock import AsyncMock, MagicMock, patch |
| 13 | +import aiohttp |
13 | 14 | import pytest |
| 15 | +from cohere.base_client import AsyncBaseCohere |
14 | 16 | from cohere.client import AsyncClient |
15 | 17 | from cohere.client_v2 import AsyncClientV2 |
| 18 | +from cohere.core.http_client import AsyncHttpClient |
16 | 19 | from cohere.errors import BadRequestError, UnauthorizedError, TooManyRequestsError, InternalServerError |
17 | 20 | from cohere.core.api_error import ApiError |
18 | 21 |
|
@@ -536,5 +539,76 @@ async def mock_stream_context(*args, **kwargs): |
536 | 539 | assert mock_response.json.call_count >= 1 |
537 | 540 |
|
538 | 541 |
|
| 542 | +# ---- Regression tests for aio-libs/aiohttp#1142: force_close / allow_redirects fix ---- |
| 543 | + |
| 544 | + |
| 545 | +@pytest.mark.asyncio |
| 546 | +async def test_connector_never_force_closes_connections(): |
| 547 | + """ |
| 548 | + Regression: TCPConnector must NOT receive force_close=True regardless of the |
| 549 | + follow_redirects setting. The previous implementation on AsyncBaseCohere used |
| 550 | + `force_close=not follow_redirects`, which disabled keep-alive connection pooling |
| 551 | + and caused TIME_WAIT port exhaustion when making thousands of concurrent requests. |
| 552 | + Redirect behaviour is now handled per-request via allow_redirects instead. |
| 553 | + """ |
| 554 | + for follow_redirects_val in (True, False): |
| 555 | + with patch("aiohttp.TCPConnector") as mock_connector_cls, patch("aiohttp.ClientSession"): |
| 556 | + mock_connector_cls.return_value = MagicMock() |
| 557 | + # AsyncBaseCohere is where follow_redirects is consumed and the connector is built |
| 558 | + AsyncBaseCohere(token="test-key", follow_redirects=follow_redirects_val) |
| 559 | + |
| 560 | + call_kwargs = mock_connector_cls.call_args.kwargs if mock_connector_cls.call_args else {} |
| 561 | + assert call_kwargs.get("force_close", False) is not True, ( |
| 562 | + f"force_close must not be True when follow_redirects={follow_redirects_val!r}; " |
| 563 | + "this kills TCP keep-alive and exhausts ephemeral ports under high concurrency" |
| 564 | + ) |
| 565 | + |
| 566 | + |
| 567 | +@pytest.mark.asyncio |
| 568 | +async def test_allow_redirects_forwarded_per_request(): |
| 569 | + """ |
| 570 | + follow_redirects on the client is forwarded as allow_redirects= on each |
| 571 | + aiohttp session.request() call. This is the correct per-request mechanism; |
| 572 | + the connector-level force_close must NOT be used for this purpose. |
| 573 | + """ |
| 574 | + mock_response = AsyncMock() |
| 575 | + mock_response.status = 200 |
| 576 | + mock_response.status_code = 200 |
| 577 | + mock_response.headers = {} |
| 578 | + mock_response.read = AsyncMock(return_value=b"{}") |
| 579 | + mock_response.content_type = "application/json" |
| 580 | + mock_response.text = "{}" |
| 581 | + |
| 582 | + for follow_redirects_val in (True, False): |
| 583 | + captured_request_kwargs: dict = {} |
| 584 | + |
| 585 | + async def capture_request(*args, **kwargs): |
| 586 | + captured_request_kwargs.update(kwargs) |
| 587 | + return mock_response |
| 588 | + |
| 589 | + mock_session = MagicMock() |
| 590 | + mock_session.request = capture_request |
| 591 | + |
| 592 | + http_client = AsyncHttpClient( |
| 593 | + aiohttp_session=mock_session, |
| 594 | + base_timeout=lambda: 30.0, |
| 595 | + base_headers=lambda: {"Authorization": "Bearer test"}, |
| 596 | + base_url=lambda: "https://api.cohere.com", |
| 597 | + follow_redirects=follow_redirects_val, |
| 598 | + ) |
| 599 | + try: |
| 600 | + await http_client.request(method="GET", path="/v1/chat") |
| 601 | + except Exception: |
| 602 | + pass # response parsing may fail; we only care about the kwargs forwarded |
| 603 | + |
| 604 | + assert "allow_redirects" in captured_request_kwargs, ( |
| 605 | + "allow_redirects must be passed to session.request() — " |
| 606 | + "redirect handling belongs at the request level, not the connector level" |
| 607 | + ) |
| 608 | + assert captured_request_kwargs["allow_redirects"] == follow_redirects_val, ( |
| 609 | + f"allow_redirects should equal follow_redirects={follow_redirects_val!r}" |
| 610 | + ) |
| 611 | + |
| 612 | + |
539 | 613 | if __name__ == "__main__": |
540 | 614 | asyncio.run(main()) |
0 commit comments