Skip to content

Commit ffadeb5

Browse files
committed
support file
1 parent 6604b9b commit ffadeb5

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

kubechat/chat/websocket/common_consumer.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
class CommonConsumer(BaseConsumer):
2020
async def connect(self):
2121
await super().connect()
22+
self.file = None
23+
self.file_name = None
2224
self.pipeline = CommonPipeline(bot=self.bot, collection=self.collection, history=self.history)
2325
self.use_default_token = self.pipeline.predictor.use_default_token
2426

@@ -27,16 +29,12 @@ async def predict(self, query, **kwargs):
2729
yield msg
2830

2931
async def receive(self, text_data=None, bytes_data=None):
30-
data = json.loads(text_data)
31-
self.msg_type = data["type"]
32-
self.response_type = "message"
33-
3432
message = ""
3533
message_id = f"{now_unix_milliseconds()}"
3634

37-
self.file = None
35+
# 先在text_data里传file_name,再在bytes_data里传file_content,最后问问题
3836
if bytes_data:
39-
file_name = data["file_name"]
37+
file_name = self.file_name
4038
file_suffix = os.path.splitext(file_name)[1].lower()
4139
if file_suffix not in DEFAULT_FILE_READER_CLS.keys():
4240
error = f"unsupported file type {file_suffix}"
@@ -50,6 +48,19 @@ async def receive(self, text_data=None, bytes_data=None):
5048
reader = DEFAULT_FILE_READER_CLS[file_suffix]
5149
docs = reader.load_data(temp_file.name)
5250
self.file = docs[0].text
51+
return
52+
53+
data = json.loads(text_data)
54+
self.msg_type = data["type"]
55+
self.response_type = "message"
56+
57+
if self.msg_type == "file_upload":
58+
self.file_name = data["file_name"]
59+
return
60+
elif self.msg_type == "cancel_upload":
61+
self.file = None
62+
self.file_name = None
63+
return
5364

5465
try:
5566
# send start message
@@ -79,6 +90,8 @@ async def receive(self, text_data=None, bytes_data=None):
7990
logger.warning("[Oops] %s: %s", str(e), traceback.format_exc())
8091
await self.send(text_data=fail_response(message_id, str(e)))
8192
finally:
93+
self.file = None
94+
self.file_name = None
8295
if self.use_default_token and self.conversation_limit:
8396
await self.manage_quota_usage()
8497
# send stop message

kubechat/pipeline/base_pipeline.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,11 @@ def __init__(self,
6262
self.use_related_question = bot_config.get("use_related_question", False)
6363

6464
welcome = bot_config.get("welcome", {})
65-
if welcome:
66-
faq = welcome.get("faq", [])
67-
self.welcome_question = []
68-
for qa in faq:
69-
self.welcome_question.append(qa["question"])
70-
self.oops = welcome.get("oops", "")
65+
faq = welcome.get("faq", [])
66+
self.welcome_question = []
67+
for qa in faq:
68+
self.welcome_question.append(qa["question"])
69+
self.oops = welcome.get("oops", "")
7170

7271
if self.memory:
7372
self.prompt_template = self.llm_config.get("memory_prompt_template", None)

0 commit comments

Comments
 (0)