Skip to content

Commit 8bcb845

Browse files
authored
Merge pull request #12 from SUJALGOYALL/feature/newsqa-improved-answering
feat: enhance NewsQAAgent with relevance scoring and question type detection
2 parents 1935d04 + 3d6be79 commit 8bcb845

File tree

2 files changed

+292
-10
lines changed

2 files changed

+292
-10
lines changed
Lines changed: 222 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,238 @@
1+
# from app.config.adk_config import AGENT_CONFIGS
2+
3+
# QA_INSTRUCTION = """
4+
# You are the News QA Agent. Answer user questions using the news corpus. Be concise and cite the article title in your answer. If no answer is found, say 'No relevant article found.'
5+
# """
6+
7+
# def create_news_qa_agent():
8+
# config = AGENT_CONFIGS["news_qa_agent"]
9+
# class NewsQAAgent:
10+
# def __init__(self):
11+
# self.name = config["name"]
12+
# self.model = config["model"]
13+
# self.description = config["description"]
14+
# self.instruction = QA_INSTRUCTION
15+
# self.tools = []
16+
# def answer(self, articles, question):
17+
# keywords = question.lower().split() if question else []
18+
# for article in articles:
19+
# content = article.get('content', '').lower()
20+
# if any(word in content for word in keywords):
21+
# return {
22+
# 'title': article.get('title', ''),
23+
# 'answer': article.get('content', '')
24+
# }
25+
# return {'answer': 'No relevant article found.'}
26+
# return NewsQAAgent()
27+
28+
# app/adk/agents/news_qa_agent.py
29+
130
from app.config.adk_config import AGENT_CONFIGS
231

332
# Instruction text for the News QA agent; create_news_qa_agent() stores it on
# the agent as ``self.instruction``.
# NOTE(review): the literal below still carries interleaved diff markers and
# the superseded one-line wording from the commit view — confirm the intended
# text against the repository file before relying on it.
QA_INSTRUCTION = """
4-
You are the News QA Agent. Answer user questions using the news corpus. Be concise and cite the article title in your answer. If no answer is found, say 'No relevant article found.'
33+
You are the News QA Agent. Answer user questions using the provided news corpus.
34+
Be concise and cite the article titles in your answer.
35+
If no relevant answer is found, reply with 'No relevant article found.'
536
"""
637

738
def create_news_qa_agent():
    """Build a News QA agent configured from ``AGENT_CONFIGS["news_qa_agent"]``.

    Returns:
        A ``NewsQAAgent`` instance. Its ``answer(articles, question)`` method
        ranks article dicts by keyword overlap with the question and returns
        a dict holding the answer text and the titles it was drawn from.

    Raises:
        KeyError: If the ``"news_qa_agent"`` entry, or its ``"name"``,
            ``"model"`` or ``"description"`` key, is missing from the config.
    """
    config = AGENT_CONFIGS["news_qa_agent"]

    class NewsQAAgent:
        """Keyword-overlap question answering over a list of article dicts.

        Articles are only read through their ``"title"`` and ``"content"``
        keys; a missing or ``None`` value is treated as empty text.
        """

        # Trimmed from both ends of every token so that "earnings?" in a
        # question matches "earnings" / "earnings," in an article.
        _EDGE_PUNCTUATION = ".,;:!?\"'()[]{}"

        # Checked in order; the first category with a matching keyword wins,
        # so temporal cues take precedence and "factual" is the last resort.
        _QUESTION_TYPES = (
            (
                "temporal",
                (
                    "trend", "change", "over time", "recently", "history",
                    "evolution", "growth", "decline", "past", "has the",
                    "has been",
                ),
            ),
            ("comparative", ("compare", "difference", "vs", "versus")),
            ("causal", ("impact", "effect", "consequence", "cause")),
            ("factual", ("what", "how", "why", "when", "where")),
        )

        def __init__(self):
            self.name = config["name"]
            self.model = config["model"]
            self.description = config["description"]
            self.instruction = QA_INSTRUCTION
            self.tools = []

        # ------------------------------------------------------------
        # Tokenization
        # ------------------------------------------------------------
        @classmethod
        def _words(cls, text):
            """Return the lowercase tokens of *text*, in order.

            Tokens are whitespace-separated, have edge punctuation removed,
            and are dropped when nothing is left. ``None`` yields ``[]``.
            """
            trimmed = (
                raw.strip(cls._EDGE_PUNCTUATION)
                for raw in (text or "").lower().split()
            )
            # Empty tokens must go: "" is a substring of every string and
            # would otherwise earn a bonus against any article.
            return [token for token in trimmed if token]

        # ------------------------------------------------------------
        # 1. Relevance scoring
        # ------------------------------------------------------------
        def _calculate_relevance_score(self, article, question):
            """Score how well *article* matches *question*.

            Each distinct question word found as a whole word in the title
            adds 3, each one found as a whole word in the content adds 1,
            and each one occurring anywhere in the content (including inside
            a longer word, e.g. "report" in "reported") adds a further 0.5.

            Returns:
                A non-negative number; 0 means no overlap at all.
            """
            content = (article.get("content") or "").lower()
            question_words = set(self._words(question))
            title_words = set(self._words(article.get("title")))
            content_words = set(self._words(content))

            score = 3 * len(question_words & title_words)
            score += len(question_words & content_words)
            score += 0.5 * sum(1 for word in question_words if word in content)
            return score

        def _rank_articles_by_relevance(self, articles, question):
            """Return the articles with a positive score, best match first.

            Articles with equal scores keep their input order (stable sort).
            """
            scored = [
                (article, self._calculate_relevance_score(article, question))
                for article in articles
            ]
            relevant = [pair for pair in scored if pair[1] > 0]
            relevant.sort(key=lambda pair: pair[1], reverse=True)
            return [article for article, _ in relevant]

        # ------------------------------------------------------------
        # 2. Answer extraction
        # ------------------------------------------------------------
        def _extract_relevant_excerpts(self, article, question, max_length=200):
            """Return the sentences of *article* that best match *question*.

            Sentences are picked in order of how many question words they
            contain until the next one would push the excerpt to
            *max_length* characters or more, then emitted in their original
            document order so the excerpt reads naturally.

            Returns:
                A string shorter than *max_length*; ``""`` when the article
                has no content or no sentence mentions a question word.
            """
            content = article.get("content") or ""
            question_words = set(self._words(question))
            if not content or not question_words:
                return ""

            candidates = []
            for position, sentence in enumerate(content.split(". ")):
                sentence = sentence.strip()
                if not sentence:
                    continue
                lowered = sentence.lower()
                matches = sum(1 for word in question_words if word in lowered)
                if matches:
                    candidates.append((matches, position, sentence))

            # Most matches first; ties fall back to document order.
            candidates.sort(key=lambda item: (-item[0], item[1]))

            chosen = []
            used = 0
            for _, position, sentence in candidates:
                # Splitting on ". " removed the period of every sentence but
                # the last one; restore it without doubling an existing one.
                if not sentence.endswith((".", "!", "?")):
                    sentence += "."
                cost = len(sentence) + (1 if chosen else 0)  # +1 joining space
                if used + cost >= max_length:
                    if not chosen:
                        # Even the best sentence is too long on its own:
                        # return a truncated piece rather than nothing.
                        cut = sentence[: max(max_length - 1, 0)].rstrip()
                        chosen.append((position, cut))
                    break
                chosen.append((position, sentence))
                used += cost

            chosen.sort(key=lambda item: item[0])
            return " ".join(text for _, text in chosen)

        # ------------------------------------------------------------
        # 3. Question type detection
        # ------------------------------------------------------------
        def _detect_question_type(self, question):
            """Classify *question* by the first keyword category it hits.

            Keywords must match at the start of a word, so "change" matches
            "changed" but not "exchange", and "cause" does not match
            "because".

            Returns:
                One of ``"temporal"``, ``"comparative"``, ``"causal"``,
                ``"factual"`` or ``"general"``.
            """
            # Leading space on both sides turns a substring test into a
            # "keyword starts a word" test, multi-word phrases included.
            normalized = " " + " ".join(self._words(question))
            for label, keywords in self._QUESTION_TYPES:
                if any(" " + keyword in normalized for keyword in keywords):
                    return label
            return "general"

        # ------------------------------------------------------------
        # 4. Handling question types
        # ------------------------------------------------------------
        def _handle_question_type(self, question, articles):
            """Dispatch to the handler for the detected question type.

            ``"factual"`` and ``"general"`` questions share the general
            handler.
            """
            handlers = {
                "comparative": self._handle_comparative_question,
                "temporal": self._handle_temporal_question,
                "causal": self._handle_causal_question,
            }
            handler = handlers.get(
                self._detect_question_type(question),
                self._handle_general_question,
            )
            return handler(question, articles)

        # The four handlers currently produce the same kind of answer; they
        # are kept separate as extension points for type-specific logic.
        def _handle_general_question(self, question, articles):
            """Default handling for factual/general questions."""
            return self._generate_answer(articles, question)

        def _handle_comparative_question(self, question, articles):
            """Handle questions that compare information across sources."""
            return self._generate_answer(articles, question)

        def _handle_temporal_question(self, question, articles):
            """Handle time/trend-based questions."""
            return self._generate_answer(articles, question)

        def _handle_causal_question(self, question, articles):
            """Handle cause-effect questions."""
            return self._generate_answer(articles, question)

        # ------------------------------------------------------------
        # 5. Generate final answer
        # ------------------------------------------------------------
        def _generate_answer(self, articles, question):
            """Compose an answer from excerpts of the first three articles.

            Each article contributing an excerpt yields a
            ``"Source N: <title>"`` line followed by the excerpt; the
            sections are separated by blank lines.

            Returns:
                The answer text, or a fixed "nothing found" message when
                *articles* is empty or no article yields an excerpt.
            """
            if not articles:
                return "No relevant articles found."

            answers = []
            for i, article in enumerate(articles[:3]):
                excerpt = self._extract_relevant_excerpts(article, question)
                if excerpt:
                    source = f"Source {i+1}: {article.get('title', 'Untitled')}"
                    answers.append(f"{source}\n{excerpt}")

            if not answers:
                return "No relevant information found in the articles."

            return "\n\n".join(answers)

        # ------------------------------------------------------------
        # 6. Public API
        # ------------------------------------------------------------
        def answer(self, articles, question):
            """Answer *question* from *articles*.

            Args:
                articles: Iterable of dicts with ``"title"`` and
                    ``"content"`` entries.
                question: The user's question as a string.

            Returns:
                ``{"answer": <message>}`` when the input is empty or no
                article is relevant; otherwise a dict with the keys
                ``"answer"``, ``"sources"`` (titles of up to three top
                articles), ``"relevance_score"`` (score of the best article)
                and ``"question_type"``.
            """
            if not articles or not question:
                return {"answer": "No articles or question provided."}

            relevant_articles = self._rank_articles_by_relevance(articles, question)
            if not relevant_articles:
                return {"answer": "No relevant articles found for this question."}

            answer_text = self._handle_question_type(question, relevant_articles)
            sources = [
                article.get("title", "Untitled")
                for article in relevant_articles[:3]
            ]

            return {
                "answer": answer_text,
                "sources": sources,
                "relevance_score": self._calculate_relevance_score(
                    relevant_articles[0], question
                ),
                "question_type": self._detect_question_type(question),
            }

    return NewsQAAgent()
223+
224+
225+
226+
if __name__ == "__main__":
    # Manual smoke test. Guarded so that importing this module (for example
    # from the test suite) no longer builds an agent and prints to stdout.
    agent = create_news_qa_agent()
    articles = [
        {
            "title": "Apple Reports Strong Earnings",
            "content": "Apple Inc. reported strong quarterly earnings, beating analyst expectations.",
        },
        {
            "title": "Market Update",
            "content": "The stock market showed mixed signals today.",
        },
    ]
    result = agent.answer(articles, "What did Apple report about earnings?")
    print(result)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
from app.adk.agents.news_qa_agent import create_news_qa_agent
3+
4+
def test_relevance_scoring():
    """Check that the on-topic article outscores the unrelated one."""
    qa_agent = create_news_qa_agent()
    query = "What did Apple report about earnings?"

    apple_story = {
        'title': 'Apple Reports Strong Earnings',
        'content': 'Apple Inc. reported strong quarterly earnings, beating analyst expectations.',
    }
    market_story = {
        'title': 'Market Update',
        'content': 'The stock market showed mixed signals today.',
    }

    # Score both stories against the same query, on-topic story first.
    apple_score, market_score = (
        qa_agent._calculate_relevance_score(story, query)
        for story in (apple_story, market_story)
    )

    assert apple_score > market_score, "Apple article should have higher relevance"
25+
26+
def test_answer_extraction():
    """Verify the extracted excerpt is on-topic and stays under the length cap."""
    story_text = (
        'Apple Inc. reported strong quarterly earnings, beating analyst expectations. '
        'The company saw significant growth in iPhone sales. Revenue increased by 15% compared to last year.'
    )
    story = {'title': 'Apple Reports Strong Earnings', 'content': story_text}

    snippet = create_news_qa_agent()._extract_relevant_excerpts(
        story, "What did Apple report about earnings?"
    )

    assert 'earnings' in snippet.lower(), "Excerpt should mention earnings"
    assert len(snippet) < 200, "Excerpt should be concise (<200 chars)"
43+
44+
def test_question_type_detection():
    """Check the detected category for one question of each covered type."""
    qa_agent = create_news_qa_agent()

    # (question, expected category), checked in this order.
    expectations = (
        ("What is Apple's revenue?", 'factual'),
        ("Compare Apple and Microsoft", 'comparative'),
        ("How has the market changed recently?", 'temporal'),
    )
    for prompt, expected_type in expectations:
        assert qa_agent._detect_question_type(prompt) == expected_type
51+
52+
def test_enhanced_answer():
    """Check the shape and content of the dict returned by ``answer``."""
    corpus = [
        {
            'title': 'Apple Reports Strong Earnings',
            'content': 'Apple Inc. reported strong quarterly earnings, beating analyst expectations.',
        },
    ]

    response = create_news_qa_agent().answer(
        corpus, "What did Apple report about earnings?"
    )

    # Every documented key must be present in a successful response.
    for expected_key in ('answer', 'sources', 'relevance_score', 'question_type'):
        assert expected_key in response
    assert 'earnings' in response['answer'].lower()

0 commit comments

Comments
 (0)