-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathfact_check_agent.py
More file actions
262 lines (235 loc) · 10.1 KB
/
fact_check_agent.py
File metadata and controls
262 lines (235 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import logging
import re
from typing import Any, List, Optional
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.memory import HumanMessage, Message, SystemMessage
from erniebot_agent.prompt import PromptTemplate
from tools.utils import JsonUtil, ReportCallbackHandler
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Prompt (written in Chinese) that asks the model to extract the factual
# claims behind numeric statements in the given content and to pair each one
# with a verification question. The requested output is a JSON-like list of
# objects with the keys "fact" and "question".
# Placeholder: {{base_context}} - the content to be examined.
PLAN_VERIFICATIONS_PROMPT = """
为了验证给出的内容中数字性表述的准确性,您需要创建一系列验证问题,
用于测试原始基线响应中的事实主张。例如,如果长格式响应的一部分包含
“墨西哥-美国战争是 1846 年至 1848 年美国和墨西哥之间的武装冲突”,
那么一种可能的验证问题可以是:“墨西哥-美国战争何时开始以及结束?”
给定内容:{{base_context}}
您需要按照列表输出,输出段落中的事实和相应的验证问题。
[{"fact": <段落中的事实>, "question": <验证问题,通过结合查询和事实生成>},
{"fact": <段落中的事实>, "question": <验证问题,通过结合查询和事实生成>}, ...]
"""
# Prompt (written in Chinese) that asks the model to answer a question using
# only the supplied external knowledge, and to output "无法回答" ("cannot
# answer") when that knowledge is insufficient.
# Placeholders: {{question}} - the question; {{content}} - external knowledge.
# NOTE: the identifier is spelled "ANWSER" (sic); it is referenced by that
# name elsewhere in this file.
ANWSER_PROMPT = """
根据外部知识回答问题。如果给出的外部知识不能解答提出的问题,
请直接输出"无法回答",无需提供答案。给定问题:\n{{question}}\n
外部知识:{{content}}\n回答:
"""
# Prompt (written in Chinese) that asks the model to judge, from a question
# and its answer, whether the numeric description in a fact is correct, and to
# provide a correction when it is not. The requested output is a JSON object
# with the keys "is_correct" and "modify".
# Placeholders: {{question}}, {{answer}}, {{claim}} (the fact under test).
CHECK_CLAIM_PROMPT = """
根据给定的问题和回答,判断事实中数字描述是否正确。如果认为事实中的数字描述不正确,
请根据问题和回答提供修正。您的输出应为 JSON 格式:
{"is_correct": <事实是否正确>, "modify": <对不正确的事实进行修正>}
给定问题: {{question}}\n回答: {{answer}}\n事实: {{claim}}
"""
# Prompt (written in Chinese) that asks the model to rewrite the original
# content so that its numbers agree with the numbers in the supplied
# background knowledge.
# Placeholders: {{origin_content}} - text to rewrite; {{context}} - background
# knowledge to align with.
FINAL_RESPONSE_PROMPT = """
根据提供的背景知识,对原始内容进行改写。确保改写后的内容中的数字与背景知识中的数字一致。
您必须修正原始内容中的数字。原始内容:{{origin_content}}\n背景知识:{{context}}
改写内容:
"""
class FactCheckerAgent(JsonUtil):
    """Agent that checks, and when needed corrects, numeric statements in a report.

    For every paragraph of the report that contains a number, the agent asks
    the LLM for (fact, verification question) pairs, looks up evidence for each
    question in ``retriever_db``, lets the LLM answer the question from that
    evidence, asks the LLM whether the fact agrees with the answer, and finally
    rewrites the paragraph if any fact was judged incorrect.
    """

    DEFAULT_SYSTEM_MESSAGE = "你是一个事实检查助手,你的任务就是检查文本中的事实描述是否正确"

    # Matches a standalone run of ASCII digits. ``re.ASCII`` is required: in the
    # default Unicode mode CJK characters count as word characters, so ``\b``
    # never matches between e.g. "1846" and "年" and numbers embedded in
    # Chinese text would go undetected.
    _NUMBER_PATTERN = re.compile(r"\b\d+\b", re.ASCII)

    def __init__(
        self,
        name: str,
        llm: BaseERNIEBot,
        retriever_db: Any,
        system_message: Optional[SystemMessage] = None,
        callbacks=None,
    ):
        """
        Initialize a fact-checker agent.

        Args:
            name: The name of the agent.
            llm: The LLM the agent chats with.
            retriever_db: Object whose ``search(question)`` method is used to
                fetch evidence for a verification question.
            system_message: Optional system message; its ``content`` is stored.
                ``DEFAULT_SYSTEM_MESSAGE`` is used when omitted.
            callbacks: Optional callback handler. A ``ReportCallbackHandler``
                is created when omitted.
        """
        self.name = name
        self.llm = llm
        self.retriever_db = retriever_db
        self.prompt_plan_verifications = PromptTemplate(
            PLAN_VERIFICATIONS_PROMPT, input_variables=["base_context"]
        )
        self.prompt_anwser = PromptTemplate(
            ANWSER_PROMPT, input_variables=["question", "content"]
        )
        self.prompt_check_claim = PromptTemplate(
            CHECK_CLAIM_PROMPT, input_variables=["question", "answer", "claim"]
        )
        self.prompt_final_response = PromptTemplate(
            FINAL_RESPONSE_PROMPT, input_variables=["origin_content", "context"]
        )
        self.system_message = (
            system_message.content
            if system_message is not None
            else self.DEFAULT_SYSTEM_MESSAGE
        )
        if callbacks is None:
            self._callback_manager = ReportCallbackHandler()
        else:
            self._callback_manager = callbacks

    async def run(self, report: str):
        """
        Run the agent on a report, notifying the callback handler before and after.

        Args:
            report: The report text to check.

        Returns:
            The result of ``_run`` for the report.
        """
        await self._callback_manager.on_run_start(
            agent=self, agent_name=self.name, prompt=self.system_message
        )
        agent_resp = await self._run(report=report)
        await self._callback_manager.on_run_end(agent=self, response=agent_resp)
        return agent_resp

    async def generate_anwser(self, question: str, context: Any):
        """
        Ask the LLM to answer a question from the supplied background knowledge.

        Args:
            question: The question to be answered.
            context: Background knowledge relevant to the question. It is
                forwarded unchanged to the answer prompt; ``verifications``
                passes a list of evidence strings here.

        Returns:
            The ``content`` of the LLM response.
        """
        messages: List[Message] = [
            HumanMessage(
                content=self.prompt_anwser.format(question=question, content=context)
            )
        ]
        response = await self.llm.chat(messages)
        return response.content

    async def check_claim(self, question: str, answer: str, claim: str):
        """
        Ask the LLM whether a fact agrees with the answer to its verification question.

        Args:
            question: The verification question.
            answer: The answer generated for that question.
            claim: The fact to be verified.

        Returns:
            The LLM response parsed by ``parse_json``. The prompt requests an
            object with the keys ``is_correct`` and ``modify``.
        """
        messages: List[Message] = [
            HumanMessage(
                content=self.prompt_check_claim.format(
                    question=question, answer=answer, claim=claim
                )
            )
        ]
        response = await self.llm.chat(messages)
        return self.parse_json(response.content)

    async def verifications(self, facts_problems: List[dict]):
        """
        Verify every (fact, question) pair against retrieved evidence.

        For each item this searches ``retriever_db`` with the question, has the
        LLM answer the question from the retrieved contents, checks the fact
        against that answer and records the outcome on the item itself
        (keys ``evidence``, ``anwser``, ``is_correct`` and ``modify``).

        Args:
            facts_problems: Dictionaries with the keys ``question`` and ``fact``.

        Returns:
            The same list, with each dictionary updated in place.
        """
        for item in facts_problems:
            question = item["question"]
            claim = item["fact"]
            hits = self.retriever_db.search(question)
            evidence = [hit["content"] for hit in hits]
            item["evidence"] = evidence
            anwser = await self.generate_anwser(question, evidence)
            item["anwser"] = anwser
            result = await self.check_claim(question, anwser, claim)
            item["is_correct"] = result["is_correct"]
            # Only an explicit ``False`` verdict replaces the fact with the
            # suggested correction; anything else keeps the original claim.
            if result["is_correct"] is False:
                item["modify"] = result["modify"]
            else:
                item["modify"] = claim
            self._callback_manager._agent_info(
                msg=item["modify"], subject="事实验证的结果", state="End"
            )
        return facts_problems

    async def generate_final_response(self, content: str, verifications: List[dict]):
        """
        Produce the final text for a paragraph from its verification results.

        If every verification item is judged correct, the original content is
        returned unchanged. Otherwise the LLM is asked to rewrite the content
        using the concatenated ``modify`` values as background knowledge.

        Args:
            content: Original text content.
            verifications: Dictionaries carrying ``is_correct`` and ``modify``.

        Returns:
            The original content, or the ``content`` of the LLM rewrite.
        """
        if all(item["is_correct"] for item in verifications):
            return content
        context = "".join([item["modify"] for item in verifications])
        messages: List[Message] = [
            HumanMessage(
                content=self.prompt_final_response.format(
                    origin_content=content, context=context
                )
            )
        ]
        response = await self.llm.chat(messages)
        return response.content

    async def report_fact(self, report: str):
        """
        Fact-check the number-bearing paragraphs of a report.

        The report is split on blank lines. Paragraphs that start with ``#``
        (headings), that are empty, or that contain no standalone number are
        kept as they are. Every other paragraph goes through fact extraction,
        verification and, if necessary, rewriting.

        Example:
            For "The Mexican-American War was the armed conflict between the
            United States and Mexico from 1846 to 1848", a possible
            verification question is "When did the Mexican-American War begin
            and end?".

        Args:
            report: The original text content.

        Returns:
            The processed paragraphs joined back together with blank lines.
        """
        text = []
        for item in report.split("\n\n"):
            # ``startswith`` (rather than indexing ``[0]``) keeps empty or
            # whitespace-only paragraphs from raising ``IndexError``.
            is_heading = item.strip().startswith("#")
            if is_heading or not self._NUMBER_PATTERN.search(item):
                text.append(item)
                continue
            messages: List[Message] = [
                HumanMessage(
                    content=self.prompt_plan_verifications.format(base_context=item)
                )
            ]
            response = await self.llm.chat(messages)
            planned: List[dict] = self.parse_json(response.content)
            fact_check_result: List[dict] = await self.verifications(planned)
            new_item: str = await self.generate_final_response(item, fact_check_result)
            text.append(new_item)
        return "\n\n".join(text)

    async def _run(self, report: str):
        """
        Core agent logic: fact-check the report.

        Args:
            report: The report text to check.

        Returns:
            The processed report text.
        """
        return await self.report_fact(report)