Skip to content

Commit 107d351

Browse files
committed
imporve code for memory and prompt
1 parent ffadeb5 commit 107d351

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

kubechat/chat/websocket/common_consumer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ async def predict(self, query, **kwargs):
3131
async def receive(self, text_data=None, bytes_data=None):
3232
message = ""
3333
message_id = f"{now_unix_milliseconds()}"
34+
related_question = []
3435

3536
# 先在text_data里传file_name,再在bytes_data里传file_content,最后问问题
3637
if bytes_data:
@@ -95,4 +96,4 @@ async def receive(self, text_data=None, bytes_data=None):
9596
if self.use_default_token and self.conversation_limit:
9697
await self.manage_quota_usage()
9798
# send stop message
98-
await self.send(text_data=stop_response(message_id, [], related_question, self.pipeline.memory_count))
99+
await self.send(text_data=stop_response(message_id, [], related_question, self.related_question_prompt, self.pipeline.memory_count))

kubechat/pipeline/common_pipeline.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@ def __init__(self, **kwargs):
1717
super().__init__(**kwargs)
1818

1919
# self.prompt_template=None
20-
if not self.prompt_template:
21-
if self.memory:
22-
self.prompt_template = COMMON_MEMORY_TEMPLATE
23-
else:
24-
self.prompt_template = COMMON_TEMPLATE
20+
if self.memory:
21+
self.prompt_template = COMMON_MEMORY_TEMPLATE
22+
elif not self.prompt_template:
23+
self.prompt_template = COMMON_TEMPLATE
2524
self.prompt = PromptTemplate(template=self.prompt_template, input_variables=["query"])
2625
self.file_prompt = PromptTemplate(template=COMMON_FILE_TEMPLATE,
2726
input_variables=["query", "context"])
@@ -51,6 +50,9 @@ async def run(self, message, gen_references=False, message_id="", file=None):
5150
related_questions = []
5251
response = ""
5352

53+
need_generate_answer = True
54+
need_related_question = True
55+
5456
if self.oops != "":
5557
response = self.oops
5658
yield self.oops

kubechat/views/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ async def update_bot(request, bot_id, bot_in: BotIn):
712712
new_config = json.loads(bot_in.config)
713713
model = new_config.get("model")
714714
llm_config = new_config.get("llm")
715-
valid, msg = validate_bot_config(model, llm_config)
715+
valid, msg = validate_bot_config(model, llm_config, bot)
716716
if not valid:
717717
return fail(HTTPStatus.BAD_REQUEST, msg)
718718
old_config = json.loads(bot.config)

kubechat/views/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from config import settings
1313
from kubechat.chat.history.redis import RedisChatMessageHistory
1414
from kubechat.llm.base import Predictor, PredictorType
15-
from kubechat.db.models import ssl_file_path, ssl_temp_file_path
15+
from kubechat.db.models import ssl_file_path, ssl_temp_file_path, BotType
1616
from kubechat.db.ops import query_chat_feedbacks, logger, PagedResult
1717
from kubechat.source.base import get_source, CustomSourceInitializationError
1818
from kubechat.utils.utils import AVAILABLE_SOURCE
@@ -82,7 +82,7 @@ def validate_source_connect_config(config: Dict) -> (bool, str):
8282
return True, ""
8383

8484

85-
def validate_bot_config(model, config: Dict) -> (bool, str):
85+
def validate_bot_config(model, config: Dict, bot) -> (bool, str):
8686
try:
8787
Predictor.from_model(model, PredictorType.CUSTOM_LLM, **config)
8888
except Exception as e:
@@ -91,14 +91,20 @@ def validate_bot_config(model, config: Dict) -> (bool, str):
9191
try:
9292
# validate the prompt
9393
prompt_template = config.get("prompt_template", None)
94-
PromptTemplate(template=prompt_template, input_variables=["query", "context"])
94+
if bot.type == BotType.KNOWLEDGE:
95+
PromptTemplate(template=prompt_template, input_variables=["query", "context"])
96+
elif bot.type == BotType.COMMON:
97+
# PromptTemplate(template=prompt_template, input_variables=["query"])
98+
pass
99+
else:
100+
return False, "Unsupported bot type"
95101
except ValidationError:
96102
return False, "Invalid prompt template"
97103

98104
try:
99105
# validate the memory prompt
100106
prompt_template = config.get("memory_prompt_template", None)
101-
if prompt_template:
107+
if prompt_template and bot.type == BotType.KNOWLEDGE:
102108
PromptTemplate(template=prompt_template, input_variables=["query", "context"])
103109
except ValidationError:
104110
return False, "Invalid memory prompt template"

0 commit comments

Comments
 (0)