|
19 | 19 | from core.utils.utils import record_llm_call |
20 | 20 | logger = get_configured_logger("ranking_engine") |
21 | 21 |
|
# Lazy-loaded NLWebScorer singleton: None until _get_nlweb_scorer() builds it.
_nlweb_scorer = None
# Cached availability check: None means "not checked yet"; afterwards it holds
# the boolean computed by _nlweb_scorer_available().
_nlweb_scorer_is_available = None  # Cached availability check
| 25 | + |
| 26 | +def _nlweb_scorer_available(): |
| 27 | + """Check if NLWebScorer checkpoints exist and config is enabled (cached).""" |
| 28 | + global _nlweb_scorer_is_available |
| 29 | + if _nlweb_scorer_is_available is not None: |
| 30 | + return _nlweb_scorer_is_available |
| 31 | + from core.config import CONFIG |
| 32 | + scorer_config = CONFIG.nlweb.scoring.get("nlwebscorer", {}) |
| 33 | + if not scorer_config.get("enabled"): |
| 34 | + _nlweb_scorer_is_available = False |
| 35 | + return False |
| 36 | + import os |
| 37 | + nlweb_root = os.path.dirname(os.path.dirname(os.path.dirname( |
| 38 | + os.path.dirname(os.path.abspath(__file__))))) |
| 39 | + bert_path = scorer_config.get("bert_checkpoint") |
| 40 | + gam_path = scorer_config.get("gam_checkpoint") |
| 41 | + if not bert_path or not gam_path: |
| 42 | + logger.warning("NLWebScorer enabled but checkpoint paths not configured, using LLM scorer") |
| 43 | + _nlweb_scorer_is_available = False |
| 44 | + return False |
| 45 | + bert_cp = os.path.join(nlweb_root, bert_path) |
| 46 | + gam_cp = os.path.join(nlweb_root, gam_path) |
| 47 | + _nlweb_scorer_is_available = os.path.exists(bert_cp) and os.path.exists(gam_cp) |
| 48 | + if _nlweb_scorer_is_available: |
| 49 | + logger.info("NLWebScorer checkpoints found, will use as default scorer") |
| 50 | + else: |
| 51 | + logger.info("NLWebScorer checkpoints not found, using LLM scorer") |
| 52 | + return _nlweb_scorer_is_available |
| 53 | + |
| 54 | +def _get_nlweb_scorer(): |
| 55 | + """Get or create the NLWebScorer instance (lazy-loaded on first use).""" |
| 56 | + global _nlweb_scorer |
| 57 | + if _nlweb_scorer is None: |
| 58 | + from core.config import CONFIG |
| 59 | + scorer_config = CONFIG.nlweb.scoring.get("nlwebscorer", {}) |
| 60 | + |
| 61 | + import os, sys |
| 62 | + # Resolve NLWeb root: ranking.py -> core -> python -> code -> AskAgent -> NLWeb |
| 63 | + nlweb_root = os.path.dirname(os.path.dirname(os.path.dirname( |
| 64 | + os.path.dirname(os.path.abspath(__file__))))) |
| 65 | + bert_cp = os.path.join(nlweb_root, scorer_config.get("bert_checkpoint", "")) |
| 66 | + gam_cp = os.path.join(nlweb_root, scorer_config.get("gam_checkpoint", "")) |
| 67 | + |
| 68 | + scorer_dir = os.path.join(nlweb_root, "NLWebScorer") |
| 69 | + if scorer_dir not in sys.path: |
| 70 | + sys.path.append(scorer_dir) # append, not insert — NLWebScorer/config/ would shadow app's config |
| 71 | + from inference.scorer import NLWebScorer |
| 72 | + |
| 73 | + logger.info(f"Loading NLWebScorer: bert={bert_cp}, gam={gam_cp}") |
| 74 | + _nlweb_scorer = NLWebScorer( |
| 75 | + bert_checkpoint=bert_cp, |
| 76 | + gam_checkpoint=gam_cp, |
| 77 | + max_length=scorer_config.get("max_length", 1024), |
| 78 | + ) |
| 79 | + logger.info("NLWebScorer loaded successfully") |
| 80 | + return _nlweb_scorer |
| 81 | + |
22 | 82 |
|
23 | 83 | class Ranking: |
24 | 84 |
|
@@ -197,10 +257,50 @@ async def rankItem(self, url, json_str, name, site): |
197 | 257 |
|
198 | 258 | except Exception as e: |
199 | 259 | # Import here to avoid circular import |
200 | | - from config.config import CONFIG |
| 260 | + from core.config import CONFIG |
201 | 261 | if CONFIG.should_raise_exceptions(): |
202 | 262 | raise # Re-raise in testing/development mode |
203 | 263 |
|
async def rankItemsWithScorer(self):
    """Batch-score all items using NLWebScorer (no LLM calls).

    Builds one ``{"name", "schema_json"}`` entry per ``(url, json_str, name,
    site)`` tuple in ``self.items``, runs ``scorer.score`` in a worker thread,
    and appends one answer dict per item to ``self.rankedAnswers``. Returns
    immediately, without loading the scorer, when there are no items.

    Raises:
        ValueError: If the scorer returns a different number of results than
            the number of items submitted.
    """
    # Snapshot so the items sent to the scorer and the items paired with its
    # results are the same sequence even if self.items changes while awaiting.
    items = list(self.items or ())
    if not items:
        return

    scorer = _get_nlweb_scorer()
    query = self.handler.decontextualized_query or self.handler.query

    # Build items for scorer — full schema, let BERT handle semantics
    scorer_items = [
        {
            "name": name,
            "schema_json": json.dumps(json_str) if isinstance(json_str, dict) else json_str,
        }
        for _url, json_str, name, _site in items
    ]

    results = await asyncio.to_thread(scorer.score, query, scorer_items)
    if len(results) != len(items):
        raise ValueError(
            f"NLWebScorer returned {len(results)} results for {len(items)} items"
        )

    logger.debug(f"NLWebScorer results for: '{query}' ({len(results)} items)")
    debug_rows = []
    for (url, json_str, name, site), result in zip(items, results):
        score = result["score"]
        schema_object = json_str if isinstance(json_str, dict) else json.loads(json_str)
        # A JSON array is represented by its first element.
        if isinstance(schema_object, list) and schema_object:
            schema_object = schema_object[0]

        # Prefer the schema's description, then its name, then the item name.
        desc = name
        if isinstance(schema_object, dict):
            desc = schema_object.get("description", schema_object.get("name", name))
        if isinstance(desc, str) and len(desc) > 200:
            desc = desc[:200] + "..."

        ansr = {
            'url': url, 'site': site, 'name': name,
            'ranking': {"score": score, "description": desc},
            'schema_object': schema_object, 'sent': False
        }
        self.rankedAnswers.append(ansr)
        debug_rows.append((score, name))

    debug_rows.sort(key=lambda row: row[0], reverse=True)
    for score, name in debug_rows:
        # "!s:>3" renders ints exactly like ":3d" but cannot raise on a float
        # score, and str(name) tolerates a non-string name; a debug line must
        # never abort ranking after the answers have been built.
        logger.debug(f" {score!s:>3} - {str(name)[:70]}")
    logger.debug("=== end scores ===")
| 303 | + |
204 | 304 | def shouldSend(self, result): |
205 | 305 | # Get max_results from handler, or use default |
206 | 306 | max_results = getattr(self.handler, 'max_results', self.NUM_RESULTS_TO_SEND) |
@@ -322,18 +422,34 @@ async def sendMessageOnSitesBeingAsked(self, top_embeddings): |
322 | 422 | self.handler.connection_alive_event.clear() |
323 | 423 |
|
324 | 424 | async def do(self): |
325 | | - |
326 | | - tasks = [] |
327 | | - for url, json_str, name, site in self.items: |
328 | | - if self.handler.connection_alive_event.is_set(): # Only add new tasks if connection is still alive |
329 | | - tasks.append(asyncio.create_task(self.rankItem(url, json_str, name, site))) |
330 | | - |
331 | | - # await self.sendMessageOnSitesBeingAsked(self.items) |
332 | 425 |
|
333 | | - try: |
334 | | - await asyncio.gather(*tasks, return_exceptions=True) |
335 | | - except Exception as e: |
336 | | - return |
| 426 | + # Determine scorer: auto-detect NLWebScorer if available, allow override via ?scorer=llm |
| 427 | + scorer_param = self.handler.query_params.get('scorer', [None]) |
| 428 | + scorer_override = scorer_param[0] if isinstance(scorer_param, list) else scorer_param |
| 429 | + if scorer_override == "llm": |
| 430 | + use_nlwebscorer = False |
| 431 | + elif scorer_override == "nlwebscorer": |
| 432 | + use_nlwebscorer = True |
| 433 | + else: |
| 434 | + # Auto-detect: use NLWebScorer if checkpoints exist |
| 435 | + use_nlwebscorer = _nlweb_scorer_available() |
| 436 | + |
| 437 | + if use_nlwebscorer: |
| 438 | + try: |
| 439 | + await self.rankItemsWithScorer() |
| 440 | + except Exception as e: |
| 441 | + logger.error(f"NLWebScorer scoring failed: {e}", exc_info=True) |
| 442 | + return |
| 443 | + else: |
| 444 | + tasks = [] |
| 445 | + for url, json_str, name, site in self.items: |
| 446 | + if self.handler.connection_alive_event.is_set(): |
| 447 | + tasks.append(asyncio.create_task(self.rankItem(url, json_str, name, site))) |
| 448 | + |
| 449 | + try: |
| 450 | + await asyncio.gather(*tasks, return_exceptions=True) |
| 451 | + except Exception as e: |
| 452 | + return |
337 | 453 |
|
338 | 454 | if not self.handler.connection_alive_event.is_set(): |
339 | 455 | return |
|
0 commit comments