Skip to content

Commit b1a7a51

Browse files
Add support for Ollama and support for OpenAI models directly (#65)
* Add OpenAI support to connector_creator.py
* Add Ollama support to connector_creator.py
1 parent eacb968 commit b1a7a51

File tree

7 files changed

+116
-29
lines changed

7 files changed

+116
-29
lines changed

examples/connector_creator_usage_example.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,11 @@ class Joke(BaseModel):
239239

240240
# model_url_and_name = os.getenv("LLAMA_URL")
241241
# model_url_and_name = os.getenv("GIGACHAT_URL")
242-
model_url_and_name = os.getenv("DEEPSEEK_URL")
242+
# model_url_and_name = os.getenv("DEEPSEEK_URL")
243243
# model_url_and_name = os.getenv("DEEPSEEK_R1_URL")
244244
# model_url_and_name = os.getenv("GPT4_URL")
245+
# model_url_and_name = os.getenv("OPENAI_URL")
246+
model_url_and_name = os.getenv("OLLAMA_URL")
245247

246248
# Uncomment the example you want to run
247249
basic_call_example(model_url_and_name)

poetry.lock

Lines changed: 38 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

protollm/connectors/README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ model name with a semicolon (;), for example: `https://api.vsegpt.ru/v1;openai/g
2121
It is also possible to pass additional parameters for the model. Available parameters:
2222
- `temperature`
2323
- `top_p`
24+
- `max_tokens`
2425

25-
Before use, make sure that your config file has the necessary API key (`VSEGPT_KEY` by default), or in the case of
26-
Gigachat models, an authorisation key (`AUTHORIZATION_KEY`), which can be obtained from your personal account.
26+
Before use, make sure that your config file has the necessary API key (`VSE_GPT_KEY` by default or `OPENAI_KEY`), or in
27+
the case of Gigachat models, an authorisation key (`AUTHORIZATION_KEY`), which can be obtained from your personal
28+
account.
2729

2830
Example of how to use the function:
29-
```commandline
31+
```codeblock
3032
from protollm.connectors.connector_creator import create_llm_connector
3133
3234
model = create_llm_connector("https://api.vsegpt.ru/v1;openai/gpt-4o-mini", temperature=0.015, top_p=0.95)
@@ -35,8 +37,6 @@ print(res.content)
3537
```
3638
The rest of the examples are located in the `examples/connector_creator_usage_examples.py` module of the repository.
3739

38-
39-
4040
## New connectors
4141

4242
For now connectors are available for services supporting the OpenAI API format, as well as for Gigachat family models.

protollm/connectors/connector_creator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from langchain_core.tools import BaseTool
99
from langchain_core.runnables import Runnable
1010
from langchain_gigachat import GigaChat
11+
from langchain_ollama import ChatOllama
1112
from langchain_openai import ChatOpenAI
1213
from pydantic import BaseModel, ValidationError
1314

@@ -254,7 +255,7 @@ def _handle_system_prompt(msgs, sys_prompt):
254255
return msgs
255256

256257

257-
def create_llm_connector(model_url: str, *args: Any, **kwargs: Any) -> CustomChatOpenAI | GigaChat:
258+
def create_llm_connector(model_url: str, *args: Any, **kwargs: Any) -> CustomChatOpenAI | GigaChat | ChatOpenAI:
258259
"""Creates the proper connector for a given LLM service URL.
259260
260261
Args:
@@ -263,6 +264,8 @@ def create_llm_connector(model_url: str, *args: Any, **kwargs: Any) -> CustomCha
263264
- for Gigachat models family: 'https://gigachat.devices.sberbank.ru/api/v1/chat/completions;Gigachat'
264265
for Gigachat model you should also install certificates from 'НУЦ Минцифры' -
265266
instructions - 'https://developers.sber.ru/docs/ru/gigachat/certificates'
267+
- for OpenAI for example: 'https://api.openai.com/v1;gpt-4o'
268+
- for Ollama for example: 'ollama;http://localhost:11434;llama3.2'
266269
267270
Returns:
268271
The ChatModel object from 'langchain' that can be used to make requests to the LLM service,
@@ -276,6 +279,12 @@ def create_llm_connector(model_url: str, *args: Any, **kwargs: Any) -> CustomCha
276279
model_name = model_url.split(";")[1]
277280
access_token = get_access_token()
278281
return GigaChat(model=model_name, access_token=access_token, *args, **kwargs)
282+
elif "api.openai" in model_url:
283+
model_name = model_url.split(";")[1]
284+
return ChatOpenAI(model=model_name, api_key=os.getenv("OPENAI_KEY"), *args, **kwargs)
285+
elif "ollama" in model_url:
286+
url_and_name = model_url.split(";")
287+
return ChatOllama(model=url_and_name[2], base_url=url_and_name[1], *args, **kwargs)
279288
elif model_url == "test_model":
280289
return CustomChatOpenAI(model_name=model_url, api_key="test")
281290
else:

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "ProtoLLM"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
description = "A library with which to prototype LLM-based applications quickly and easily."
55
requires-python = ">=3.10,<4.0"
66
authors = [
@@ -24,7 +24,7 @@ dependencies = [
2424
"langchain>=0.3.4,<0.4.0",
2525
"langchain-chroma==0.1.4",
2626
"langchain-community==0.3.16",
27-
"langchain-core==0.3.34",
27+
"langchain-core>=0.3.34",
2828
"langchain-elasticsearch==0.3.2",
2929
"langchain-gigachat==0.3.3",
3030
"langchain-openai==0.3.3",
@@ -55,7 +55,8 @@ dependencies = [
5555
"transformers==4.48.2",
5656
"urllib3>=2.2.2,<3.0.0",
5757
"uuid>=1.30,<2.0.0",
58-
"websockets==14.1"
58+
"websockets==14.1",
59+
"langchain-ollama (==0.3.0)"
5960
]
6061

6162
[project.urls]

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@ kombu==5.4.2 ; python_version >= "3.12" and python_version < "4.0" or python_ver
8383
kubernetes==32.0.1 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
8484
langchain-chroma==0.1.4 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
8585
langchain-community==0.3.16 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
86-
langchain-core==0.3.34 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
86+
langchain-core==0.3.48 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
8787
langchain-elasticsearch==0.3.2 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
8888
langchain-gigachat==0.3.3 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
89+
langchain-ollama==0.3.0 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
8990
langchain-openai==0.3.3 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
9091
langchain-text-splitters==0.3.6 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
9192
langchain==0.3.18 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
@@ -125,6 +126,7 @@ networkx==3.4.2 ; python_version >= "3.12" and python_version < "3.14" or python
125126
nltk==3.9.1 ; python_version >= "3.12" and python_version < "3.14" or python_version >= "3.10" and python_version <= "3.11"
126127
numpy==1.26.4 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
127128
oauthlib==3.2.2 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
129+
ollama==0.4.7 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
128130
onnxruntime==1.21.0 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
129131
openai==1.61.1 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"
130132
opentelemetry-api==1.31.0 ; python_version >= "3.12" and python_version < "4.0" or python_version >= "3.10" and python_version <= "3.11"

tests/test_connector.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from unittest.mock import patch
22

3-
from langchain_core.language_models.chat_models import BaseChatModel
43
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
54
from langchain_core.tools import tool
5+
from langchain_gigachat import GigaChat
6+
from langchain_ollama import ChatOllama
7+
from langchain_openai import ChatOpenAI
68
from pydantic import BaseModel, Field
79
import pytest
810

@@ -307,16 +309,53 @@ def test_structured_output_dict_out_of_the_box(custom_chat_openai_with_fc_and_so
307309
assert result["age"] == 30
308310

309311

310-
@pytest.mark.parametrize(
311-
"model_url",
312-
[
313-
"https://api.vsegpt.ru/v1;openai/gpt-4o-mini",
314-
"https://gigachat.devices.sberbank.ru/api/v1/chat/completions;GigaChat",
315-
"test_model",
316-
"https://example.com/v1;test/example_model"
317-
]
318-
)
319-
def test_connector_creator(model_url):
320-
with pytest.raises(Exception):
321-
connector = create_llm_connector(model_url)
322-
assert issubclass(connector, BaseChatModel)
312+
def test_vsegpt_connector(monkeypatch):
313+
model_url = "https://api.vsegpt.ru/v1;meta-llama/llama-3.1-70b-instruct"
314+
test_api_key = "test_vsegpt_key"
315+
monkeypatch.setenv("VSE_GPT_KEY", test_api_key)
316+
connector = create_llm_connector(model_url)
317+
assert isinstance(connector, CustomChatOpenAI)
318+
319+
320+
@patch("protollm.connectors.connector_creator.get_access_token", return_value="test_gigachat_token")
321+
def test_gigachat_connector(mock_get_token):
322+
model_url = "https://gigachat.devices.sberbank.ru/api/v1;Gigachat"
323+
connector = create_llm_connector(model_url)
324+
assert isinstance(connector, GigaChat)
325+
326+
327+
def test_openai_connector(monkeypatch):
328+
model_url = "https://api.openai.com/v1;gpt-4o"
329+
test_api_key = "test_openai_key"
330+
monkeypatch.setenv("OPENAI_KEY", test_api_key)
331+
connector = create_llm_connector(model_url)
332+
assert isinstance(connector, ChatOpenAI)
333+
334+
335+
def test_ollama_connector():
336+
model_url = "ollama;http://localhost:11434;llama3.2"
337+
connector = create_llm_connector(model_url)
338+
assert isinstance(connector, ChatOllama)
339+
340+
341+
def test_test_model_connector():
342+
model_url = "test_model"
343+
connector = create_llm_connector(model_url)
344+
assert isinstance(connector, CustomChatOpenAI)
345+
346+
347+
def test_unsupported_provider():
348+
model_url = "https://unknown.provider/v1;some-model"
349+
with pytest.raises(ValueError) as exc_info:
350+
create_llm_connector(model_url)
351+
assert "Unsupported provider URL" in str(exc_info.value)
352+
353+
354+
@pytest.mark.parametrize("invalid_url", [
355+
"invalid_url_without_semicolon",
356+
"https://api.vsegpt.ru/v1",
357+
";;",
358+
])
359+
def test_invalid_url_format(invalid_url):
360+
with pytest.raises(ValueError):
361+
create_llm_connector(invalid_url)

0 commit comments

Comments (0)