Skip to content

Commit f41df8e

Browse files
fix: share ML model instances to reduce startup time
The backend startup was slow because RetrieverTools.initialize() creates 6 retriever chains, and each one independently loaded its own copy of the embedding model (thenlper/gte-large) and reranker model (BAAI/bge-reranker-base). That meant 12 heavy model loads when only 2 are actually needed, since all chains use the same model config. This fix creates both models once at the top of initialize() and passes the shared instances down through HybridRetrieverChain, SimilarityRetrieverChain, and FAISSVectorDatabase. Both models are stateless (they only run encode/score inference) so sharing a single instance across all chains is safe. Each chain still builds its own independent FAISS index with its own documents. Startup model loading goes from ~34s to ~7s on a local machine (4.9x). Resolves #88 Signed-off-by: Harsh Kumar <harshkumar3446@gmail.com>
1 parent decfd26 commit f41df8e

File tree

6 files changed

+134
-12
lines changed

6 files changed

+134
-12
lines changed

backend/src/agents/retriever_tools.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import os
2+
import logging
23
from typing import Tuple, Optional, Union
34
from dotenv import load_dotenv
45

56
from langchain_core.tools import tool
67
from langchain.retrievers import EnsembleRetriever
78
from langchain.retrievers import ContextualCompressionRetriever
9+
from langchain_huggingface import HuggingFaceEmbeddings
10+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
11+
from langchain_google_vertexai import VertexAIEmbeddings
12+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
813

914
from ..chains.hybrid_retriever_chain import HybridRetrieverChain
1015
from ..tools.format_docs import format_docs
@@ -39,13 +44,49 @@ def __init__(self) -> None:
3944
]
4045
tool_descriptions: str = ""
4146

47+
@staticmethod
48+
def _create_embedding_model(
49+
embeddings_config: dict[str, str],
50+
use_cuda: bool = False,
51+
) -> Union[HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings]:
52+
embeddings_type = embeddings_config["type"]
53+
embeddings_model_name = embeddings_config["name"]
54+
55+
if embeddings_type == "GOOGLE_GENAI":
56+
logging.info("Using Google GenerativeAI embeddings...")
57+
return GoogleGenerativeAIEmbeddings(
58+
model=embeddings_model_name,
59+
task_type="retrieval_document",
60+
)
61+
elif embeddings_type == "GOOGLE_VERTEXAI":
62+
logging.info("Using Google VertexAI embeddings...")
63+
return VertexAIEmbeddings(model_name=embeddings_model_name)
64+
elif embeddings_type == "HF":
65+
logging.info("Using HuggingFace embeddings...")
66+
model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}
67+
return HuggingFaceEmbeddings(
68+
model_name=embeddings_model_name,
69+
multi_process=False,
70+
encode_kwargs={"normalize_embeddings": True},
71+
model_kwargs=model_kwargs,
72+
)
73+
else:
74+
raise ValueError("Invalid embeddings type specified.")
75+
4276
def initialize(
4377
self,
4478
embeddings_config: dict[str, str],
4579
reranking_model_name: str,
4680
use_cuda: bool = False,
4781
fast_mode: bool = False,
4882
) -> None:
83+
# Create shared model instances once
84+
embedding_model = self._create_embedding_model(embeddings_config, use_cuda)
85+
logging.info("Shared embedding model created.")
86+
87+
reranker_model = HuggingFaceCrossEncoder(model_name=reranking_model_name)
88+
logging.info("Shared reranker model created.")
89+
4990
markdown_docs_map = {
5091
"general": [
5192
"./data/markdown/OR_docs",
@@ -100,6 +141,8 @@ def initialize(
100141
contextual_rerank=True,
101142
search_k=search_k,
102143
chunk_size=chunk_size,
144+
embedding_model=embedding_model,
145+
reranker_model=reranker_model,
103146
)
104147
general_retriever_chain.create_hybrid_retriever()
105148
RetrieverTools.general_retriever = general_retriever_chain.retriever
@@ -115,6 +158,8 @@ def initialize(
115158
contextual_rerank=True,
116159
search_k=search_k,
117160
chunk_size=chunk_size,
161+
embedding_model=embedding_model,
162+
reranker_model=reranker_model,
118163
)
119164
install_retriever_chain.create_hybrid_retriever()
120165
RetrieverTools.install_retriever = install_retriever_chain.retriever
@@ -131,6 +176,8 @@ def initialize(
131176
contextual_rerank=True,
132177
search_k=search_k,
133178
chunk_size=chunk_size,
179+
embedding_model=embedding_model,
180+
reranker_model=reranker_model,
134181
)
135182
commands_retriever_chain.create_hybrid_retriever()
136183
RetrieverTools.commands_retriever = commands_retriever_chain.retriever
@@ -146,6 +193,8 @@ def initialize(
146193
contextual_rerank=True,
147194
search_k=search_k,
148195
chunk_size=chunk_size,
196+
embedding_model=embedding_model,
197+
reranker_model=reranker_model,
149198
)
150199
yosys_rtdocs_retriever_chain.create_hybrid_retriever()
151200
RetrieverTools.yosys_rtdocs_retriever = yosys_rtdocs_retriever_chain.retriever
@@ -161,6 +210,8 @@ def initialize(
161210
contextual_rerank=True,
162211
search_k=search_k,
163212
chunk_size=chunk_size,
213+
embedding_model=embedding_model,
214+
reranker_model=reranker_model,
164215
)
165216
klayout_retriever_chain.create_hybrid_retriever()
166217
RetrieverTools.klayout_retriever = klayout_retriever_chain.retriever
@@ -176,6 +227,8 @@ def initialize(
176227
contextual_rerank=True,
177228
search_k=search_k,
178229
chunk_size=chunk_size,
230+
embedding_model=embedding_model,
231+
reranker_model=reranker_model,
179232
)
180233
errinfo_retriever_chain.create_hybrid_retriever()
181234
RetrieverTools.errinfo_retriever = errinfo_retriever_chain.retriever

backend/src/chains/hybrid_retriever_chain.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def __init__(
3838
weights: list[float] = [0.33, 0.33, 0.33],
3939
chunk_size: int = 500,
4040
contextual_rerank: bool = False,
41+
embedding_model=None,
42+
reranker_model: Optional[HuggingFaceCrossEncoder] = None,
4143
):
4244
super().__init__(
4345
llm_model=llm_model,
@@ -48,6 +50,8 @@ def __init__(
4850

4951
self.reranking_model_name: Optional[str] = reranking_model_name
5052
self.use_cuda: bool = use_cuda
53+
self.embedding_model = embedding_model
54+
self.reranker_model = reranker_model
5155

5256
self.search_k: int = search_k
5357
self.weights: list[float] = weights
@@ -74,6 +78,7 @@ def create_hybrid_retriever(self) -> None:
7478
html_docs_path=self.html_docs_path,
7579
chunk_size=self.chunk_size,
7680
use_cuda=self.use_cuda,
81+
embedding_model=self.embedding_model,
7782
)
7883
if self.vector_db is None:
7984
cur_path = os.path.abspath(__file__)
@@ -121,8 +126,11 @@ def create_hybrid_retriever(self) -> None:
121126
)
122127

123128
if self.contextual_rerank:
129+
reranker = self.reranker_model or HuggingFaceCrossEncoder(
130+
model_name=self.reranking_model_name
131+
)
124132
compressor = CrossEncoderReranker(
125-
model=HuggingFaceCrossEncoder(model_name=self.reranking_model_name),
133+
model=reranker,
126134
top_n=self.search_k,
127135
)
128136
self.retriever = ContextualCompressionRetriever(

backend/src/chains/similarity_retriever_chain.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
embeddings_config: Optional[dict[str, str]] = None,
2929
use_cuda: bool = False,
3030
chunk_size: int = 500,
31+
embedding_model=None,
3132
):
3233
super().__init__(
3334
llm_model=llm_model,
@@ -40,6 +41,7 @@ def __init__(
4041

4142
self.embeddings_config: Optional[dict[str, str]] = embeddings_config
4243
self.use_cuda: bool = use_cuda
44+
self.embedding_model = embedding_model
4345

4446
self.markdown_docs_path: Optional[list[str]] = markdown_docs_path
4547
self.other_docs_path: Optional[list[str]] = other_docs_path
@@ -125,6 +127,7 @@ def create_vector_db(self) -> None:
125127
embeddings_model_name=self.embeddings_config["name"],
126128
embeddings_type=self.embeddings_config["type"],
127129
use_cuda=self.use_cuda,
130+
embedding_model=self.embedding_model,
128131
)
129132
else:
130133
raise ValueError("Embeddings model config not provided correctly.")

backend/src/vectorstores/faiss.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,21 @@ def __init__(
2828
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
2929
debug: bool = False,
3030
use_cuda: bool = False,
31+
embedding_model: Optional[
32+
Union[
33+
HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings
34+
]
35+
] = None,
3136
):
3237
self.embeddings_model_name = embeddings_model_name
3338

34-
model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}
35-
3639
self.embedding_model: Union[
3740
HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings
3841
]
3942

40-
if embeddings_type == "GOOGLE_GENAI":
43+
if embedding_model is not None:
44+
self.embedding_model = embedding_model
45+
elif embeddings_type == "GOOGLE_GENAI":
4146
self.embedding_model = GoogleGenerativeAIEmbeddings(
4247
model=self.embeddings_model_name,
4348
task_type="retrieval_document",
@@ -51,6 +56,7 @@ def __init__(
5156
logging.info("Using Google VertexAI embeddings...")
5257

5358
elif embeddings_type == "HF":
59+
model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}
5460
self.embedding_model = HuggingFaceEmbeddings(
5561
model_name=self.embeddings_model_name,
5662
multi_process=False,

backend/tests/test_retriever_tools.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,18 @@ def test_init(self):
1414
# Check that it's a valid instance
1515
assert isinstance(tools, RetrieverTools)
1616

17+
@patch("src.agents.retriever_tools.HuggingFaceCrossEncoder")
18+
@patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model")
1719
@patch("src.agents.retriever_tools.HybridRetrieverChain")
18-
def test_initialize_success(self, mock_hybrid_chain):
20+
def test_initialize_success(
21+
self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder
22+
):
1923
"""Test successful initialization of all retrievers."""
2024
tools = RetrieverTools()
2125

26+
mock_create_embed.return_value = Mock()
27+
mock_cross_encoder.return_value = Mock()
28+
2229
# Mock the HybridRetrieverChain instances
2330
mock_chains = []
2431
for i in range(
@@ -55,11 +62,18 @@ def test_initialize_success(self, mock_hybrid_chain):
5562
assert RetrieverTools.klayout_retriever == mock_chains[4].retriever
5663
assert RetrieverTools.errinfo_retriever == mock_chains[5].retriever
5764

65+
@patch("src.agents.retriever_tools.HuggingFaceCrossEncoder")
66+
@patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model")
5867
@patch("src.agents.retriever_tools.HybridRetrieverChain")
59-
def test_initialize_with_fast_mode(self, mock_hybrid_chain):
68+
def test_initialize_with_fast_mode(
69+
self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder
70+
):
6071
"""Test initialization with fast mode enabled."""
6172
tools = RetrieverTools()
6273

74+
mock_create_embed.return_value = Mock()
75+
mock_cross_encoder.return_value = Mock()
76+
6377
# Mock the HybridRetrieverChain instances
6478
mock_chains = []
6579
for i in range(6):
@@ -250,11 +264,18 @@ def test_retrieve_klayout_docs_not_initialized(self):
250264
with pytest.raises(ValueError, match="KLayout Retriever not initialized"):
251265
RetrieverTools.retrieve_klayout_docs.invoke(input="test query")
252266

267+
@patch("src.agents.retriever_tools.HuggingFaceCrossEncoder")
268+
@patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model")
253269
@patch("src.agents.retriever_tools.HybridRetrieverChain")
254-
def test_initialize_verifies_configuration_parameters(self, mock_hybrid_chain):
270+
def test_initialize_verifies_configuration_parameters(
271+
self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder
272+
):
255273
"""Test that initialize passes correct configuration parameters."""
256274
tools = RetrieverTools()
257275

276+
mock_create_embed.return_value = Mock()
277+
mock_cross_encoder.return_value = Mock()
278+
258279
# Mock the HybridRetrieverChain instances
259280
mock_chains = []
260281
for i in range(6):
@@ -283,11 +304,18 @@ def test_initialize_verifies_configuration_parameters(self, mock_hybrid_chain):
283304
assert kwargs["weights"] == [0.6, 0.2, 0.2]
284305
assert kwargs["contextual_rerank"] is True
285306

307+
@patch("src.agents.retriever_tools.HuggingFaceCrossEncoder")
308+
@patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model")
286309
@patch("src.agents.retriever_tools.HybridRetrieverChain")
287-
def test_initialize_with_environment_variables(self, mock_hybrid_chain):
310+
def test_initialize_with_environment_variables(
311+
self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder
312+
):
288313
"""Test initialization respects environment variables."""
289314
tools = RetrieverTools()
290315

316+
mock_create_embed.return_value = Mock()
317+
mock_cross_encoder.return_value = Mock()
318+
291319
# Mock the HybridRetrieverChain instances
292320
mock_chains = []
293321
for i in range(6):
@@ -323,11 +351,18 @@ def test_tool_decorators_applied(self):
323351
assert hasattr(RetrieverTools.retrieve_yosys_rtdocs, "name")
324352
assert hasattr(RetrieverTools.retrieve_klayout_docs, "name")
325353

354+
@patch("src.agents.retriever_tools.HuggingFaceCrossEncoder")
355+
@patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model")
326356
@patch("src.agents.retriever_tools.HybridRetrieverChain")
327-
def test_different_docs_paths_for_retrievers(self, mock_hybrid_chain):
357+
def test_different_docs_paths_for_retrievers(
358+
self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder
359+
):
328360
"""Test that different retrievers use different document paths."""
329361
tools = RetrieverTools()
330362

363+
mock_create_embed.return_value = Mock()
364+
mock_cross_encoder.return_value = Mock()
365+
331366
# Mock the HybridRetrieverChain instances
332367
mock_chains = []
333368
for i in range(6):
@@ -369,11 +404,18 @@ def test_different_docs_paths_for_retrievers(self, mock_hybrid_chain):
369404
# Errinfo should have error-specific paths
370405
assert any("man3" in path for path in errinfo_paths)
371406

407+
@patch("src.agents.retriever_tools.HuggingFaceCrossEncoder")
408+
@patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model")
372409
@patch("src.agents.retriever_tools.HybridRetrieverChain")
373-
def test_html_docs_configuration(self, mock_hybrid_chain):
410+
def test_html_docs_configuration(
411+
self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder
412+
):
374413
"""Test HTML docs configuration for specific retrievers."""
375414
tools = RetrieverTools()
376415

416+
mock_create_embed.return_value = Mock()
417+
mock_cross_encoder.return_value = Mock()
418+
377419
# Mock the HybridRetrieverChain instances
378420
mock_chains = []
379421
for i in range(6):
@@ -426,11 +468,18 @@ def test_staticmethod_decorators(self):
426468
result = RetrieverTools.retrieve_general.invoke(input="test")
427469
assert result == ("", [], [], [])
428470

471+
@patch("src.agents.retriever_tools.HuggingFaceCrossEncoder")
472+
@patch("src.agents.retriever_tools.RetrieverTools._create_embedding_model")
429473
@patch("src.agents.retriever_tools.HybridRetrieverChain")
430-
def test_retriever_chain_create_hybrid_retriever_called(self, mock_hybrid_chain):
474+
def test_retriever_chain_create_hybrid_retriever_called(
475+
self, mock_hybrid_chain, mock_create_embed, mock_cross_encoder
476+
):
431477
"""Test that create_hybrid_retriever is called on all chains."""
432478
tools = RetrieverTools()
433479

480+
mock_create_embed.return_value = Mock()
481+
mock_cross_encoder.return_value = Mock()
482+
434483
# Mock the HybridRetrieverChain instances
435484
mock_chains = []
436485
for i in range(6):

backend/tests/test_similarity_retriever_chain.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,10 @@ def test_create_vector_db_success(self, mock_faiss_db):
238238

239239
assert chain.vector_db == mock_db_instance
240240
mock_faiss_db.assert_called_once_with(
241-
embeddings_model_name="test-model", embeddings_type="HF", use_cuda=True
241+
embeddings_model_name="test-model",
242+
embeddings_type="HF",
243+
use_cuda=True,
244+
embedding_model=None,
242245
)
243246

244247
def test_create_vector_db_missing_config_raises_error(self):

0 commit comments

Comments
 (0)