Skip to content

Commit 03edc9a

Browse files
Fix constructor type
1 parent d144668 commit 03edc9a

File tree

2 files changed

+71
-13
lines changed

2 files changed

+71
-13
lines changed

src/cohere/client_v2.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,61 @@
11
from .client import Client, AsyncClient
22
from .v2.client import V2Client, AsyncV2Client
3+
import typing
4+
from .environment import ClientEnvironment
5+
import os
6+
import httpx
7+
from concurrent.futures import ThreadPoolExecutor
38

49

510
class ClientV2(V2Client, Client): # type: ignore
6-
__init__ = Client.__init__ # type: ignore
11+
def __init__(
12+
self,
13+
api_key: typing.Optional[typing.Union[str,
14+
typing.Callable[[], str]]] = None,
15+
*,
16+
base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
17+
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
18+
client_name: typing.Optional[str] = None,
19+
timeout: typing.Optional[float] = None,
20+
httpx_client: typing.Optional[httpx.Client] = None,
21+
thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
22+
log_warning_experimental_features: bool = True,
23+
):
24+
Client.__init__(
25+
self,
26+
api_key=api_key,
27+
base_url=base_url,
28+
environment=environment,
29+
client_name=client_name,
30+
timeout=timeout,
31+
httpx_client=httpx_client,
32+
thread_pool_executor=thread_pool_executor,
33+
log_warning_experimental_features=log_warning_experimental_features,
34+
)
735

836

937
class AsyncClientV2(AsyncV2Client, AsyncClient): # type: ignore
10-
__init__ = AsyncClient.__init__ # type: ignore
38+
def __init__(
39+
self,
40+
api_key: typing.Optional[typing.Union[str,
41+
typing.Callable[[], str]]] = None,
42+
*,
43+
base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
44+
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
45+
client_name: typing.Optional[str] = None,
46+
timeout: typing.Optional[float] = None,
47+
httpx_client: typing.Optional[httpx.AsyncClient] = None,
48+
thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
49+
log_warning_experimental_features: bool = True,
50+
):
51+
AsyncClient.__init__(
52+
self,
53+
api_key=api_key,
54+
base_url=base_url,
55+
environment=environment,
56+
client_name=client_name,
57+
timeout=timeout,
58+
httpx_client=httpx_client,
59+
thread_pool_executor=thread_pool_executor,
60+
log_warning_experimental_features=log_warning_experimental_features,
61+
)

tests/test_client_v2.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,22 @@
1414
class TestClientV2(unittest.TestCase):
1515

1616
def test_chat(self) -> None:
17-
response = co.chat(model="command-r-plus", messages=[cohere.UserChatMessageV2(content="hello world!")])
17+
response = co.chat(
18+
model="command-r-plus", messages=[cohere.UserChatMessageV2(content="hello world!")])
1819

1920
print(response.message)
2021

2122
def test_chat_stream(self) -> None:
22-
stream = co.chat_stream(model="command-r-plus", messages=[cohere.UserChatMessageV2(content="hello world!")])
23+
stream = co.chat_stream(
24+
model="command-r-plus", messages=[cohere.UserChatMessageV2(content="hello world!")])
2325

2426
events = set()
2527

2628
for chat_event in stream:
2729
if chat_event is not None:
2830
events.add(chat_event.type)
2931
if chat_event.type == "content-delta":
30-
print(chat_event.delta.message)
32+
print(chat_event.delta)
3133

3234
self.assertTrue("message-start" in events)
3335
self.assertTrue("content-start" in events)
@@ -43,10 +45,12 @@ def test_chat_documents(self) -> None:
4345
{"title": "widget sales 2021", "text": "4 million"},
4446
]
4547
response = co.chat(
46-
messages=cohere.UserChatMessageV2(
47-
content=cohere.TextContent(text="how many widges were sold in 2020?"),
48+
messages=[cohere.UserChatMessageV2(
49+
content=cohere.TextContent(
50+
text="how many widges were sold in 2020?"),
4851
documents=documents,
49-
),
52+
)],
53+
model="command-r-plus",
5054
)
5155

5256
print(response.message)
@@ -75,9 +79,12 @@ def test_chat_tools(self) -> None:
7579

7680
# call the get_weather tool
7781
tool_result = {"temperature": "30C"}
78-
tool_content = [cohere.Content(output=tool_result, text="The weather in Toronto is 30C")]
79-
messages.append(res.message)
80-
messages.append(cohere.ToolChatMessageV2(tool_call_id=res.message.tool_calls[0].id, tool_content=tool_content))
81-
82-
res = co.chat(tools=tools, messages=messages)
82+
tool_content = [cohere.Content(
83+
output=tool_result, text="The weather in Toronto is 30C")]
84+
messages.append(cohere.AssistantChatMessageV2(content=res.message))
85+
if res.message.tool_calls is not None:
86+
messages.append(cohere.ToolChatMessageV2(
87+
tool_call_id=res.message.tool_calls[0].id, tool_content=tool_content))
88+
89+
res = co.chat(tools=tools, messages=messages, model="command-r-plus")
8390
print(res.message)

0 commit comments

Comments
 (0)