Skip to content

Commit 47f518a

Browse files
committed
feat: add LLM support for GLM-4 series
1 parent 082a9a4 commit 47f518a

File tree

13 files changed

+79
-272
lines changed

13 files changed

+79
-272
lines changed

aperag/llm/base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def match_predictor(model_name, predictor_type, kwargs):
8383
case "deepseek-chat" | "gpt-4-1106-preview" | "gpt-4-vision-preview" | "gpt-4" | "gpt-4-32k" | "gpt-4-0613" | "gpt-4-32k-0613":
8484
from aperag.llm.openai import OpenAIPredictor
8585
return OpenAIPredictor
86+
case "glm-4-plus" | "glm-4-air" | "glm-4-long" | "glm-4-flashx" | "glm-4-flash":
87+
from aperag.llm.openai import OpenAIPredictor
88+
return OpenAIPredictor
8689
case "azure-openai":
8790
from aperag.llm.azure import AzureOpenAIPredictor
8891
return AzureOpenAIPredictor
@@ -92,10 +95,6 @@ def match_predictor(model_name, predictor_type, kwargs):
9295
case "ernie-bot-turbo":
9396
from aperag.llm.wenxin import BaiduQianFan
9497
return BaiduQianFan
95-
case "chatglm-pro" | "chatglm-std" | "chatglm-lite" | "chatglm-turbo":
96-
kwargs["model"] = model_name.replace("-", "_")
97-
from aperag.llm.chatglm import ChatGLMPredictor
98-
return ChatGLMPredictor
9998
case "qwen-turbo" | "qwen-plus" | "qwen-max":
10099
from aperag.llm.qianwen import QianWenPredictor
101100
return QianWenPredictor

aperag/llm/chatglm.py

Lines changed: 0 additions & 135 deletions
This file was deleted.

aperag/llm/test_agenerate_stream.py

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from base import KubeBlocksLLMPredictor
2323

2424
from aperag.llm.baichuan import BaiChuanPredictor
25-
from aperag.llm.chatglm import ChatGLMPredictor
2625
from aperag.llm.custom import CustomLLMPredictor
2726
from aperag.llm.openai import OpenAIPredictor
2827
from aperag.llm.wenxin import BaiduQianFan
@@ -306,80 +305,5 @@ async def test_stream_behavior():
306305
self.assertEqual(task_order_log, ["task2 completed.", "task1 completed."])
307306

308307

309-
class TestChatGLMPredictor(unittest.IsolatedAsyncioTestCase):
310-
311-
async def test_stream_async_behabior(self):
312-
predictor1 = ChatGLMPredictor(api_key="id.secret", endpoint="http://192.0.2.0")
313-
predictor2 = ChatGLMPredictor(api_key="id.secret")
314-
315-
task_order_log = []
316-
317-
async def async_task():
318-
319-
try:
320-
_ = [tokens async for tokens in predictor1.agenerate_stream(prompt="test")]
321-
except aiohttp.ClientConnectorError:
322-
pass
323-
task_order_log.append("task1 completed.") # 在尝试连接多次(超过60s)后结束,打印log信息
324-
325-
async def test_stream_behavior():
326-
mock_response = [
327-
"id: fb981fde-0080-4933-b87b-4a29eaba8d17",
328-
"event: add",
329-
"data: Kubernetes的核心技术",
330-
""
331-
"id: fb981fde-0080-4933-b87b-4a29eaba8d17",
332-
"event: add",
333-
"data: Service的作用是防止Pod",
334-
""
335-
"id: fb981fde-0080-4933-b87b-4a29eaba8d17",
336-
"event: add",
337-
"data: 失联(服务发现)",
338-
""
339-
"id: fb981fde-0080-4933-b87b-4a29eaba8d17",
340-
"event: add",
341-
"data: 和定义Pod访问策略",
342-
""
343-
"id: fb981fde-0080-4933-b87b-4a29eaba8d17",
344-
"event: add",
345-
"data: (负载均衡)。",
346-
""
347-
"id: fb981fde-0080-4933-b87b-4a29eaba8d17",
348-
"event: finish",
349-
]
350-
351-
mock_responses = [
352-
"Kubernetes的核心技术",
353-
"Service的作用是防止Pod",
354-
"失联(服务发现)",
355-
"和定义Pod访问策略",
356-
"(负载均衡)。"
357-
]
358-
359-
# 将数据转换为字节流,每一行后面都有一个换行符
360-
mock_content = "\n".join(mock_response).encode("utf-8")
361-
362-
url = "https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_lite/sse-invoke?temperature=0.95&top_p=0.7"
363-
prompt = "test prompt"
364-
365-
with aioresponses() as mocked:
366-
mocked.post(url=url, body=mock_content, status=200)
367-
368-
response_list = []
369-
async for resp in predictor2._agenerate_stream(prompt=prompt):
370-
response_list.append(resp)
371-
372-
self.assertEqual(response_list, mock_responses)
373-
task_order_log.append("task2 completed.") # 任务执行结束,打印日志信息
374-
375-
# 使用gather同时启动两个任务
376-
# 如果agenerate_stream是异步的,那么在task1多次尝试连接期间,task2就已经在执行中了
377-
# 那么task2一定比task1先结束: (1和2同时开始)---2结束--------1结束
378-
379-
_, _ = await asyncio.gather(async_task(), test_stream_behavior())
380-
381-
self.assertEqual(task_order_log, ["task2 completed.", "task1 completed."])
382-
383-
384308
if __name__ == "__main__":
385309
unittest.main()

aperag/views/main.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def list_models(request):
178178
"enabled": model_server.get("enabled", "true").lower() == "true",
179179
"memory": model_server.get("memory", "disabled").lower() == "enabled",
180180
"free_tier": model_server.get("free_tier", False),
181+
"endpoint": model_server.get("endpoint", ""),
181182
"default_token": Predictor.check_default_token(model_name=model_server["name"]),
182183
"prompt_template": DEFAULT_MODEL_MEMOTY_PROMPT_TEMPLATES.get(model_server["name"],
183184
DEFAULT_CHINESE_PROMPT_TEMPLATE_V3),
@@ -434,7 +435,7 @@ async def update_collection(request, collection_id, collection: CollectionIn):
434435
bot_ids = []
435436
async for bot in bots:
436437
bot_ids.append(bot.id)
437-
438+
438439
return success(instance.view(bot_ids=bot_ids))
439440

440441

@@ -467,27 +468,27 @@ async def create_questions(request, collection_id):
467468
return fail(HTTPStatus.NOT_FOUND, "Collection not found")
468469
if collection.status == CollectionStatus.QUESTION_PENDING:
469470
return fail(HTTPStatus.BAD_REQUEST, "Collection is generating questions")
470-
471+
471472
collection.status = CollectionStatus.QUESTION_PENDING
472473
await collection.asave()
473-
474+
474475
documents = await sync_to_async(collection.document_set.exclude)(status=DocumentStatus.DELETED)
475476
generate_tasks = []
476477
async for document in documents:
477478
generate_tasks.append(generate_questions.si(document.id))
478479
generate_group = group(*generate_tasks)
479480
callback_chain = chain(generate_group, update_collection_status.s(collection.id))
480481
callback_chain.delay()
481-
482-
return success({})
482+
483+
return success({})
483484

484485
@router.put("/collections/{collection_id}/questions")
485486
async def update_question(request, collection_id, question_in: QuestionIn):
486487
user = get_user(request)
487488
collection = await query_collection(user, collection_id)
488489
if collection is None:
489490
return fail(HTTPStatus.NOT_FOUND, "Collection not found")
490-
491+
491492
    # create question
492493
if not question_in.id:
493494
question_instance = Question(
@@ -499,13 +500,13 @@ async def update_question(request, collection_id, question_in: QuestionIn):
499500
else:
500501
question_instance = await query_question(user, question_in.id)
501502
if question_instance is None:
502-
return fail(HTTPStatus.NOT_FOUND, "Question not found")
503-
503+
return fail(HTTPStatus.NOT_FOUND, "Question not found")
504+
504505
question_instance.question = question_in.question
505506
question_instance.answer = question_in.answer if question_in.answer else ""
506507
question_instance.status = QuestionStatus.PENDING
507508
await sync_to_async(question_instance.documents.clear)()
508-
509+
509510
if question_in.relate_documents:
510511
for document_id in question_in.relate_documents:
511512
document = await query_document(user, collection_id, document_id)
@@ -688,7 +689,7 @@ async def update_document(
688689
await instance.asave()
689690
# if user add labels for a document, we need to update index
690691
update_index_for_document.delay(instance.id)
691-
692+
692693
related_questions = await sync_to_async(document.question_set.exclude)(status=QuestionStatus.DELETED)
693694
async for question in related_questions:
694695
question.status = QuestionStatus.WARNING
@@ -712,13 +713,13 @@ async def delete_document(request, collection_id, document_id):
712713
await document.asave()
713714

714715
remove_index.delay(document.id)
715-
716+
716717
related_questions = await sync_to_async(document.question_set.exclude)(status=QuestionStatus.DELETED)
717718
async for question in related_questions:
718719
question.documents.remove(document)
719720
question.status = QuestionStatus.WARNING
720721
await question.asave()
721-
722+
722723
return success(document.view())
723724

724725

@@ -736,13 +737,13 @@ async def delete_documents(request, collection_id, document_ids: List[str]):
736737
document.gmt_deleted = timezone.now()
737738
await document.asave()
738739
remove_index.delay(document.id)
739-
740+
740741
related_questions = await sync_to_async(document.question_set.exclude)(status=QuestionStatus.DELETED)
741742
async for question in related_questions:
742743
question.documents.remove(document)
743744
question.status = QuestionStatus.WARNING
744745
await question.asave()
745-
746+
746747
ok.append(document.id)
747748
except Exception as e:
748749
logger.exception(e)

0 commit comments

Comments
 (0)