Skip to content

Commit e978197

Browse files
authored
feat: related question prompt (#560)
1 parent 1d568bb commit e978197

File tree

3 files changed

+31
-17
lines changed

3 files changed

+31
-17
lines changed

kubechat/chat/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def start_response(message_id):
4444
)
4545

4646

47-
def stop_response(message_id, references, related_question = [], memory_count=0):
47+
def stop_response(message_id, references, related_question = [], related_question_prompt = '' ,memory_count=0):
4848
if references is None:
4949
references = []
5050
return json.dumps(
@@ -53,6 +53,7 @@ def stop_response(message_id, references, related_question = [], memory_count=0)
5353
"id": message_id,
5454
"data": references,
5555
"memoryCount": memory_count,
56+
"related_question_prompt":related_question_prompt,
5657
"related_question":related_question,
5758
"timestamp": now_unix_milliseconds()
5859
}

kubechat/chat/websocket/base_consumer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from abc import abstractmethod
55
import ast
66
from datetime import datetime
7+
import random
78

89
import websockets
910
from asgiref.sync import sync_to_async
@@ -80,12 +81,18 @@ async def connect(self):
8081

8182
message_id = f"{now_unix_milliseconds()}"
8283
bot_config = json.loads(self.bot.config)
84+
self.related_question_prompt = bot_config.get("related_question_prompt","")
85+
8386
welcome = bot_config.get("welcome",{})
8487
faq = welcome.get("faq",[])
8588
questions = []
8689
for qa in faq:
8790
questions.append(qa["question"])
88-
welcome_message = {"hello":welcome.get("hello",""), "faq":questions}
91+
if len(questions) < 3:
92+
random_questions = questions
93+
else:
94+
random_questions = random.sample(questions, 3)
95+
welcome_message = {"hello":welcome.get("hello",""), "faq":random_questions}
8996
await self.send(text_data=welcome_response(message_id, welcome_message))
9097

9198
async def disconnect(self, close_code):
@@ -163,6 +170,6 @@ async def receive(self, text_data=None, bytes_data=None):
163170
if self.use_default_token and self.conversation_limit:
164171
await self.manage_quota_usage()
165172
# send stop message
166-
await self.send(text_data=stop_response(message_id, references, related_question, self.pipeline.memory_count))
173+
await self.send(text_data=stop_response(message_id, references, related_question, self.related_question_prompt, self.pipeline.memory_count))
167174

168175

kubechat/pipeline/pipeline.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ async def run(self, message, gen_references=False, message_id=""):
323323
logger.info("[%s] start processing", log_prefix)
324324

325325
references = []
326+
related_questions = []
326327
response = ""
327328
vector = self.embedding_model.embed_query(message)
328329
logger.info("[%s] embedding query end", log_prefix)
@@ -366,8 +367,11 @@ async def run(self, message, gen_references=False, message_id=""):
366367
yield self.oops
367368
need_generate_answer = False
368369
if self.welcome_question != []:
369-
yield KUBE_CHAT_RELATED_QUESTIONS + str(self.welcome_question)
370-
need_related_question = False
370+
if len(self.welcome_question)>=3:
371+
related_questions = random.sample(self.welcome_question, 3)
372+
need_related_question = False
373+
else:
374+
related_questions = self.welcome_question
371375

372376
if self.use_related_question and need_related_question:
373377
related_question_prompt = self.related_question_prompt.format(query=message, context=context)
@@ -405,18 +409,20 @@ async def run(self, message, gen_references=False, message_id=""):
405409
await self.add_ai_message(message, message_id, response, references)
406410
logger.info("[%s] add ai message end and the pipeline is succeed", log_prefix)
407411

408-
if self.use_related_question and need_related_question:
409-
related_question = await related_question_task
410-
related_question = re.sub(r'\n+', '\n', related_question).split('\n')
411-
for i,question in enumerate(related_question):
412-
match = re.match(r"\s*-\s*(.*)", question)
413-
if match:
414-
question = match.group(1)
415-
match = re.match(r"\s*\d+\.\s*(.*)", question)
416-
if match:
417-
question = match.group(1)
418-
related_question[i] = question
419-
yield KUBE_CHAT_RELATED_QUESTIONS + str(related_question[:3])
412+
if self.use_related_question:
413+
if need_related_question:
414+
related_question_generate = await related_question_task
415+
related_question = re.sub(r'\n+', '\n', related_question_generate).split('\n')
416+
for i,question in enumerate(related_question):
417+
match = re.match(r"\s*-\s*(.*)", question)
418+
if match:
419+
question = match.group(1)
420+
match = re.match(r"\s*\d+\.\s*(.*)", question)
421+
if match:
422+
question = match.group(1)
423+
related_question[i] = question
424+
related_questions.extend(related_question)
425+
yield KUBE_CHAT_RELATED_QUESTIONS + str(related_questions[:3])
420426

421427
if gen_references:
422428
yield KUBE_CHAT_DOC_QA_REFERENCES + json.dumps(references)

0 commit comments

Comments
 (0)