Skip to content

Commit 9a27a2b

Browse files
fern-supportclaude
andcommitted
feat: auto-detect aiohttp transport for async client
When httpx_aiohttp is installed, the async client automatically uses it as the transport layer, providing native async DNS resolution via aiodns instead of httpx's blocking threadpool-based getaddrinfo. This prevents ConnectError failures under high async concurrency. Users opt in with `pip install cohere[aiohttp]` — no code changes needed. Users who want plain httpx can pass `httpx_client=DefaultAsyncHttpxClient()`. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 27e4a7b commit 9a27a2b

File tree

5 files changed

+179
-3
lines changed

5 files changed

+179
-3
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ requests = "^2.0.0"
4747
tokenizers = ">=0.15,<1"
4848
types-requests = "^2.0.0"
4949
typing_extensions = ">= 4.0.0"
50+
aiohttp = {version = ">=3.0", optional = true}
51+
httpx_aiohttp = {version = ">=0.1.8", optional = true}
52+
53+
[tool.poetry.extras]
54+
aiohttp = ["aiohttp", "httpx_aiohttp"]
5055

5156
[tool.poetry.group.dev.dependencies]
5257
mypy = "==1.13.0"

src/cohere/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@
280280
)
281281
from .bedrock_client import BedrockClient, BedrockClientV2
282282
from .client import AsyncClient, Client
283+
from ._default_clients import DefaultAioHttpClient, DefaultAsyncHttpxClient
283284
from .client_v2 import AsyncClientV2, ClientV2
284285
from .datasets import DatasetsCreateResponse, DatasetsGetResponse, DatasetsGetUsageResponse, DatasetsListResponse
285286
from .embed_jobs import CreateEmbedJobRequestTruncate
@@ -440,6 +441,8 @@
440441
"CreateEmbedJobResponse": ".types",
441442
"Dataset": ".types",
442443
"DatasetPart": ".types",
444+
"DefaultAioHttpClient": "._default_clients",
445+
"DefaultAsyncHttpxClient": "._default_clients",
443446
"DatasetType": ".types",
444447
"DatasetValidationStatus": ".types",
445448
"DatasetsCreateResponse": ".datasets",
@@ -779,6 +782,8 @@ def __dir__():
779782
"DatasetsGetResponse",
780783
"DatasetsGetUsageResponse",
781784
"DatasetsListResponse",
785+
"DefaultAioHttpClient",
786+
"DefaultAsyncHttpxClient",
782787
"DebugStreamedChatResponse",
783788
"DebugV2ChatStreamResponse",
784789
"DeleteConnectorResponse",

src/cohere/_default_clients.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import typing
2+
3+
import httpx
4+
5+
COHERE_DEFAULT_TIMEOUT = 300
6+
7+
try:
8+
import httpx_aiohttp
9+
except ImportError:
10+
11+
class DefaultAioHttpClient(httpx.AsyncClient): # type: ignore
12+
def __init__(self, **kwargs: typing.Any) -> None:
13+
raise RuntimeError(
14+
"To use the aiohttp client, install the aiohttp extra: "
15+
"pip install cohere[aiohttp]"
16+
)
17+
18+
else:
19+
20+
class DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient): # type: ignore
21+
def __init__(self, **kwargs: typing.Any) -> None:
22+
kwargs.setdefault("timeout", COHERE_DEFAULT_TIMEOUT)
23+
kwargs.setdefault("follow_redirects", True)
24+
super().__init__(**kwargs)
25+
26+
27+
class DefaultAsyncHttpxClient(httpx.AsyncClient):
28+
def __init__(self, **kwargs: typing.Any) -> None:
29+
kwargs.setdefault("timeout", COHERE_DEFAULT_TIMEOUT)
30+
kwargs.setdefault("follow_redirects", True)
31+
super().__init__(**kwargs)

src/cohere/base_client.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,6 +1565,24 @@ def finetuning(self):
15651565
return self._finetuning
15661566

15671567

1568+
def _make_default_async_client(
1569+
timeout: float,
1570+
follow_redirects: typing.Optional[bool],
1571+
) -> httpx.AsyncClient:
1572+
try:
1573+
import httpx_aiohttp
1574+
except ImportError:
1575+
pass
1576+
else:
1577+
if follow_redirects is not None:
1578+
return httpx_aiohttp.HttpxAiohttpClient(timeout=timeout, follow_redirects=follow_redirects)
1579+
return httpx_aiohttp.HttpxAiohttpClient(timeout=timeout)
1580+
1581+
if follow_redirects is not None:
1582+
return httpx.AsyncClient(timeout=timeout, follow_redirects=follow_redirects)
1583+
return httpx.AsyncClient(timeout=timeout)
1584+
1585+
15681586
class AsyncBaseCohere:
15691587
"""
15701588
Use this class to access the different functions within the SDK. You can instantiate any number of clients with different configuration that will propagate to these functions.
@@ -1631,9 +1649,7 @@ def __init__(
16311649
headers=headers,
16321650
httpx_client=httpx_client
16331651
if httpx_client is not None
1634-
else httpx.AsyncClient(timeout=_defaulted_timeout, follow_redirects=follow_redirects)
1635-
if follow_redirects is not None
1636-
else httpx.AsyncClient(timeout=_defaulted_timeout),
1652+
else _make_default_async_client(timeout=_defaulted_timeout, follow_redirects=follow_redirects),
16371653
timeout=_defaulted_timeout,
16381654
)
16391655
self._raw_client = AsyncRawBaseCohere(client_wrapper=self._client_wrapper)

tests/test_aiohttp_autodetect.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import sys
2+
import typing
3+
import unittest
4+
from unittest import mock
5+
6+
import httpx
7+
8+
9+
class TestMakeDefaultAsyncClient(unittest.TestCase):
10+
"""Tests for _make_default_async_client in base_client.py."""
11+
12+
def test_without_httpx_aiohttp_returns_httpx_async_client(self) -> None:
13+
"""When httpx_aiohttp is not installed, returns plain httpx.AsyncClient."""
14+
with mock.patch.dict(sys.modules, {"httpx_aiohttp": None}):
15+
# Re-import to pick up the mocked module state
16+
from cohere.base_client import _make_default_async_client
17+
18+
client = _make_default_async_client(timeout=300, follow_redirects=True)
19+
self.assertIsInstance(client, httpx.AsyncClient)
20+
self.assertEqual(client.timeout.read, 300)
21+
self.assertTrue(client.follow_redirects)
22+
23+
def test_without_httpx_aiohttp_follow_redirects_none(self) -> None:
24+
"""When follow_redirects is None, omits it from httpx.AsyncClient."""
25+
with mock.patch.dict(sys.modules, {"httpx_aiohttp": None}):
26+
from cohere.base_client import _make_default_async_client
27+
28+
client = _make_default_async_client(timeout=300, follow_redirects=None)
29+
self.assertIsInstance(client, httpx.AsyncClient)
30+
# httpx default is False when not specified
31+
self.assertFalse(client.follow_redirects)
32+
33+
def test_with_httpx_aiohttp_returns_aiohttp_client(self) -> None:
34+
"""When httpx_aiohttp is installed, returns HttpxAiohttpClient."""
35+
try:
36+
import httpx_aiohttp
37+
except ImportError:
38+
self.skipTest("httpx_aiohttp not installed")
39+
40+
from cohere.base_client import _make_default_async_client
41+
42+
client = _make_default_async_client(timeout=300, follow_redirects=True)
43+
self.assertIsInstance(client, httpx_aiohttp.HttpxAiohttpClient)
44+
self.assertEqual(client.timeout.read, 300)
45+
self.assertTrue(client.follow_redirects)
46+
47+
def test_with_httpx_aiohttp_follow_redirects_none(self) -> None:
48+
"""When httpx_aiohttp is installed and follow_redirects is None, omits it."""
49+
try:
50+
import httpx_aiohttp
51+
except ImportError:
52+
self.skipTest("httpx_aiohttp not installed")
53+
54+
from cohere.base_client import _make_default_async_client
55+
56+
client = _make_default_async_client(timeout=300, follow_redirects=None)
57+
self.assertIsInstance(client, httpx_aiohttp.HttpxAiohttpClient)
58+
# httpx default is False when not specified
59+
self.assertFalse(client.follow_redirects)
60+
61+
def test_explicit_httpx_client_bypasses_autodetect(self) -> None:
62+
"""When user passes httpx_client explicitly, auto-detect is not used."""
63+
explicit_client = httpx.AsyncClient(timeout=60)
64+
# Simulate what AsyncBaseCohere.__init__ does:
65+
# httpx_client if httpx_client is not None else _make_default_async_client(...)
66+
result = explicit_client if explicit_client is not None else None
67+
self.assertIs(result, explicit_client)
68+
self.assertEqual(result.timeout.read, 60)
69+
70+
71+
class TestDefaultClients(unittest.TestCase):
72+
"""Tests for convenience classes in _default_clients.py."""
73+
74+
def test_default_async_httpx_client_defaults(self) -> None:
75+
"""DefaultAsyncHttpxClient applies SDK defaults."""
76+
from cohere._default_clients import COHERE_DEFAULT_TIMEOUT, DefaultAsyncHttpxClient
77+
78+
client = DefaultAsyncHttpxClient()
79+
self.assertIsInstance(client, httpx.AsyncClient)
80+
self.assertEqual(client.timeout.read, COHERE_DEFAULT_TIMEOUT)
81+
self.assertTrue(client.follow_redirects)
82+
83+
def test_default_async_httpx_client_overrides(self) -> None:
84+
"""DefaultAsyncHttpxClient allows overriding defaults."""
85+
from cohere._default_clients import DefaultAsyncHttpxClient
86+
87+
client = DefaultAsyncHttpxClient(timeout=60, follow_redirects=False)
88+
self.assertEqual(client.timeout.read, 60)
89+
self.assertFalse(client.follow_redirects)
90+
91+
def test_default_aiohttp_client_without_package(self) -> None:
92+
"""DefaultAioHttpClient raises RuntimeError when httpx_aiohttp not installed."""
93+
with mock.patch.dict(sys.modules, {"httpx_aiohttp": None}):
94+
# Need to reload the module to pick up the mock
95+
import importlib
96+
import cohere._default_clients
97+
98+
importlib.reload(cohere._default_clients)
99+
100+
with self.assertRaises(RuntimeError) as ctx:
101+
cohere._default_clients.DefaultAioHttpClient()
102+
self.assertIn("pip install cohere[aiohttp]", str(ctx.exception))
103+
104+
# Reload again to restore original state
105+
importlib.reload(cohere._default_clients)
106+
107+
def test_default_aiohttp_client_with_package(self) -> None:
108+
"""DefaultAioHttpClient works when httpx_aiohttp is installed."""
109+
try:
110+
import httpx_aiohttp
111+
except ImportError:
112+
self.skipTest("httpx_aiohttp not installed")
113+
114+
from cohere._default_clients import COHERE_DEFAULT_TIMEOUT, DefaultAioHttpClient
115+
116+
client = DefaultAioHttpClient()
117+
self.assertIsInstance(client, httpx_aiohttp.HttpxAiohttpClient)
118+
self.assertEqual(client.timeout.read, COHERE_DEFAULT_TIMEOUT)
119+
self.assertTrue(client.follow_redirects)

0 commit comments

Comments
 (0)