Skip to content

Commit 80b3758

Browse files
authored
Merge pull request #413 from nlweb-ai/add-datafinder
v0.55 protocol, NLWebScorer integration, scorer UI cleanup
2 parents b5487c5 + 9226bac commit 80b3758

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+7939
-95
lines changed

.gitignore

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ myenv
1313
venv
1414
.env
1515
data/
16+
!NLWebScorer/data/
17+
NLWebScorer/data/*.json
18+
NLWebScorer/data/*.jsonl
19+
NLWebScorer/data/**/*.json
20+
NLWebScorer/data/**/*.jsonl
21+
NLWebScorer/data/prepared*/
22+
NLWebScorer/data/holdout*/
1623

1724

1825
# User-specific files (MonoDevelop/Xamarin Studio)
@@ -458,4 +465,9 @@ static/.DS_Store
458465
nlwm_deploy_*.zip
459466
agentfinder_deploy_*.zip
460467

461-
openai-apps-sdk-examples/
468+
openai-apps-sdk-examples/
469+
470+
# NLWebScorer - exclude checkpoints and training logs
471+
NLWebScorer/checkpoints/
472+
NLWebScorer/*.log
473+
AskAgent/set_keys.sh

AskAgent/python/core/baseHandler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ def __init__(self, query_params, http_handler):
115115
# Maximum number of results to return to the user
116116
self.max_results = get_param(query_params, "max_results", int, 10)
117117

118+
# Protocol version (v0.55 when using structured POST body)
119+
self.protocol_version = query_params.get('_protocol_version')
120+
118121
# the items that have been retrieved from the vector database, could be before decontextualization.
119122
# See below notes on fasttrack
120123
self.retrieved_items = []
@@ -413,8 +416,8 @@ async def prepare(self):
413416
)
414417
self.final_retrieved_items = items
415418
self.retrieval_done_event.set()
416-
417-
logger.info("Preparation phase completed")
419+
420+
logger.info(f"Preparation phase completed. Retrieved {len(self.final_retrieved_items)} items.")
418421

419422
def decontextualizeQuery(self):
420423
if (len(self.prev_queries) < 1):

AskAgent/python/core/config.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class NLWebConfig:
102102
who_endpoint_enabled: bool = True # Enable or disable the who endpoint
103103
api_keys: Dict[str, str] = field(default_factory=dict) # API keys for external services
104104
who_endpoint: str = "http://localhost:8000/who" # Endpoint for /who requests
105+
scoring: Dict[str, Any] = field(default_factory=dict) # Scoring configuration (e.g. nlwebscorer)
105106

106107
@dataclass
107108
class ConversationStorageConfig:
@@ -487,7 +488,10 @@ def load_nlweb_config(self, path: str = "config_nlweb.yaml"):
487488

488489
# Load who_endpoint from config
489490
who_endpoint = self._get_config_value(data.get("who_endpoint"), "http://localhost:8000/who")
490-
491+
492+
# Load scoring configuration
493+
scoring = data.get("scoring", {})
494+
491495
# Load headers from config
492496
headers = data.get("headers", {})
493497

@@ -525,7 +529,8 @@ def load_nlweb_config(self, path: str = "config_nlweb.yaml"):
525529
aggregation_enabled=aggregation_enabled,
526530
who_endpoint_enabled=who_endpoint_enabled,
527531
api_keys=api_keys,
528-
who_endpoint=who_endpoint
532+
who_endpoint=who_endpoint,
533+
scoring=scoring
529534
)
530535

531536
def get_chatbot_instructions(self, instruction_type: str = "search_results") -> str:

AskAgent/python/core/ranking.py

Lines changed: 128 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,66 @@
1919
from core.utils.utils import record_llm_call
2020
logger = get_configured_logger("ranking_engine")
2121

22+
# Lazy-loaded NLWebScorer singleton
23+
_nlweb_scorer = None
24+
_nlweb_scorer_is_available = None # Cached availability check
25+
26+
def _nlweb_scorer_available():
27+
"""Check if NLWebScorer checkpoints exist and config is enabled (cached)."""
28+
global _nlweb_scorer_is_available
29+
if _nlweb_scorer_is_available is not None:
30+
return _nlweb_scorer_is_available
31+
from core.config import CONFIG
32+
scorer_config = CONFIG.nlweb.scoring.get("nlwebscorer", {})
33+
if not scorer_config.get("enabled"):
34+
_nlweb_scorer_is_available = False
35+
return False
36+
import os
37+
nlweb_root = os.path.dirname(os.path.dirname(os.path.dirname(
38+
os.path.dirname(os.path.abspath(__file__)))))
39+
bert_path = scorer_config.get("bert_checkpoint")
40+
gam_path = scorer_config.get("gam_checkpoint")
41+
if not bert_path or not gam_path:
42+
logger.warning("NLWebScorer enabled but checkpoint paths not configured, using LLM scorer")
43+
_nlweb_scorer_is_available = False
44+
return False
45+
bert_cp = os.path.join(nlweb_root, bert_path)
46+
gam_cp = os.path.join(nlweb_root, gam_path)
47+
_nlweb_scorer_is_available = os.path.exists(bert_cp) and os.path.exists(gam_cp)
48+
if _nlweb_scorer_is_available:
49+
logger.info("NLWebScorer checkpoints found, will use as default scorer")
50+
else:
51+
logger.info("NLWebScorer checkpoints not found, using LLM scorer")
52+
return _nlweb_scorer_is_available
53+
54+
def _get_nlweb_scorer():
55+
"""Get or create the NLWebScorer instance (lazy-loaded on first use)."""
56+
global _nlweb_scorer
57+
if _nlweb_scorer is None:
58+
from core.config import CONFIG
59+
scorer_config = CONFIG.nlweb.scoring.get("nlwebscorer", {})
60+
61+
import os, sys
62+
# Resolve NLWeb root: ranking.py -> core -> python -> code -> AskAgent -> NLWeb
63+
nlweb_root = os.path.dirname(os.path.dirname(os.path.dirname(
64+
os.path.dirname(os.path.abspath(__file__)))))
65+
bert_cp = os.path.join(nlweb_root, scorer_config.get("bert_checkpoint", ""))
66+
gam_cp = os.path.join(nlweb_root, scorer_config.get("gam_checkpoint", ""))
67+
68+
scorer_dir = os.path.join(nlweb_root, "NLWebScorer")
69+
if scorer_dir not in sys.path:
70+
sys.path.append(scorer_dir) # append, not insert — NLWebScorer/config/ would shadow app's config
71+
from inference.scorer import NLWebScorer
72+
73+
logger.info(f"Loading NLWebScorer: bert={bert_cp}, gam={gam_cp}")
74+
_nlweb_scorer = NLWebScorer(
75+
bert_checkpoint=bert_cp,
76+
gam_checkpoint=gam_cp,
77+
max_length=scorer_config.get("max_length", 1024),
78+
)
79+
logger.info("NLWebScorer loaded successfully")
80+
return _nlweb_scorer
81+
2282

2383
class Ranking:
2484

@@ -197,10 +257,50 @@ async def rankItem(self, url, json_str, name, site):
197257

198258
except Exception as e:
199259
# Import here to avoid circular import
200-
from config.config import CONFIG
260+
from core.config import CONFIG
201261
if CONFIG.should_raise_exceptions():
202262
raise # Re-raise in testing/development mode
203263

264+
async def rankItemsWithScorer(self):
265+
"""Batch-score all items using NLWebScorer (no LLM calls)."""
266+
scorer = _get_nlweb_scorer()
267+
query = self.handler.decontextualized_query or self.handler.query
268+
269+
# Build items for scorer — full schema, let BERT handle semantics
270+
scorer_items = []
271+
for url, json_str, name, site in self.items:
272+
schema_json = json.dumps(json_str) if isinstance(json_str, dict) else json_str
273+
scorer_items.append({"name": name, "schema_json": schema_json})
274+
275+
results = await asyncio.to_thread(scorer.score, query, scorer_items)
276+
277+
logger.debug(f"NLWebScorer results for: '{query}' ({len(results)} items)")
278+
debug_rows = []
279+
for i, (url, json_str, name, site) in enumerate(self.items):
280+
score = results[i]["score"]
281+
schema_object = json_str if isinstance(json_str, dict) else json.loads(json_str)
282+
if isinstance(schema_object, list) and len(schema_object) > 0:
283+
schema_object = schema_object[0]
284+
285+
desc = name
286+
if isinstance(schema_object, dict):
287+
desc = schema_object.get("description", schema_object.get("name", name))
288+
if isinstance(desc, str) and len(desc) > 200:
289+
desc = desc[:200] + "..."
290+
291+
ansr = {
292+
'url': url, 'site': site, 'name': name,
293+
'ranking': {"score": score, "description": desc},
294+
'schema_object': schema_object, 'sent': False
295+
}
296+
self.rankedAnswers.append(ansr)
297+
debug_rows.append((score, name))
298+
299+
debug_rows.sort(key=lambda x: x[0], reverse=True)
300+
for score, name in debug_rows:
301+
logger.debug(f" {score:3d} - {name[:70]}")
302+
logger.debug("=== end scores ===")
303+
204304
def shouldSend(self, result):
205305
# Get max_results from handler, or use default
206306
max_results = getattr(self.handler, 'max_results', self.NUM_RESULTS_TO_SEND)
@@ -322,18 +422,34 @@ async def sendMessageOnSitesBeingAsked(self, top_embeddings):
322422
self.handler.connection_alive_event.clear()
323423

324424
async def do(self):
325-
326-
tasks = []
327-
for url, json_str, name, site in self.items:
328-
if self.handler.connection_alive_event.is_set(): # Only add new tasks if connection is still alive
329-
tasks.append(asyncio.create_task(self.rankItem(url, json_str, name, site)))
330-
331-
# await self.sendMessageOnSitesBeingAsked(self.items)
332425

333-
try:
334-
await asyncio.gather(*tasks, return_exceptions=True)
335-
except Exception as e:
336-
return
426+
# Determine scorer: auto-detect NLWebScorer if available, allow override via ?scorer=llm
427+
scorer_param = self.handler.query_params.get('scorer', [None])
428+
scorer_override = scorer_param[0] if isinstance(scorer_param, list) else scorer_param
429+
if scorer_override == "llm":
430+
use_nlwebscorer = False
431+
elif scorer_override == "nlwebscorer":
432+
use_nlwebscorer = True
433+
else:
434+
# Auto-detect: use NLWebScorer if checkpoints exist
435+
use_nlwebscorer = _nlweb_scorer_available()
436+
437+
if use_nlwebscorer:
438+
try:
439+
await self.rankItemsWithScorer()
440+
except Exception as e:
441+
logger.error(f"NLWebScorer scoring failed: {e}", exc_info=True)
442+
return
443+
else:
444+
tasks = []
445+
for url, json_str, name, site in self.items:
446+
if self.handler.connection_alive_event.is_set():
447+
tasks.append(asyncio.create_task(self.rankItem(url, json_str, name, site)))
448+
449+
try:
450+
await asyncio.gather(*tasks, return_exceptions=True)
451+
except Exception as e:
452+
return
337453

338454
if not self.handler.connection_alive_event.is_set():
339455
return

0 commit comments

Comments
 (0)