
Commit c85b110

fix: share ML model instances to reduce startup time
The backend startup was slow because RetrieverTools.initialize() creates 6 retriever chains, and each one independently loaded its own copy of the embedding model (thenlper/gte-large) and reranker model (BAAI/bge-reranker-base). That meant 12 heavy model loads when only 2 are actually needed, since all chains use the same model config.

This fix creates both models once at the top of initialize() and passes the shared instances down through HybridRetrieverChain, SimilarityRetrieverChain, and FAISSVectorDatabase. Both models are stateless (they only run encode/score inference), so sharing a single instance across all chains is safe. Each chain still builds its own independent FAISS index with its own documents.

Startup model loading goes from ~34s to ~7s on a local machine (4.9x).

Resolves #88

Signed-off-by: Harsh Kumar <harshkumar3446@gmail.com>
1 parent decfd26 commit c85b110
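For illustration only, a minimal self-contained sketch of the pattern this commit applies (not the project's actual code): each heavy model is loaded exactly once, and the same stateless instance is reused by every consumer. The corpus dict and query strings below are made-up placeholders; the model names come from the commit message.

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# Load each heavy model exactly once at startup.
embedding_model = HuggingFaceEmbeddings(model_name="thenlper/gte-large")
reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")

# Every consumer reuses the same objects; no further model loads happen here.
corpora = {"general": ["doc a", "doc b"], "install": ["doc c"]}
indexes = {name: embedding_model.embed_documents(docs) for name, docs in corpora.items()}
scores = reranker_model.score([("query", "doc a"), ("query", "doc c")])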

File tree

6 files changed, +201 -16 lines changed


backend/src/agents/retriever_tools.py

Lines changed: 53 additions & 0 deletions
@@ -1,10 +1,15 @@
 import os
+import logging
 from typing import Tuple, Optional, Union
 from dotenv import load_dotenv
 
 from langchain_core.tools import tool
 from langchain.retrievers import EnsembleRetriever
 from langchain.retrievers import ContextualCompressionRetriever
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_google_genai import GoogleGenerativeAIEmbeddings
+from langchain_google_vertexai import VertexAIEmbeddings
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
 
 from ..chains.hybrid_retriever_chain import HybridRetrieverChain
 from ..tools.format_docs import format_docs
@@ -39,13 +44,49 @@ def __init__(self) -> None:
     ]
     tool_descriptions: str = ""
 
+    @staticmethod
+    def _create_embedding_model(
+        embeddings_config: dict[str, str],
+        use_cuda: bool = False,
+    ) -> Union[HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings]:
+        embeddings_type = embeddings_config["type"]
+        embeddings_model_name = embeddings_config["name"]
+
+        if embeddings_type == "GOOGLE_GENAI":
+            logging.info("Using Google GenerativeAI embeddings...")
+            return GoogleGenerativeAIEmbeddings(
+                model=embeddings_model_name,
+                task_type="retrieval_document",
+            )
+        elif embeddings_type == "GOOGLE_VERTEXAI":
+            logging.info("Using Google VertexAI embeddings...")
+            return VertexAIEmbeddings(model_name=embeddings_model_name)
+        elif embeddings_type == "HF":
+            logging.info("Using HuggingFace embeddings...")
+            model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}
+            return HuggingFaceEmbeddings(
+                model_name=embeddings_model_name,
+                multi_process=False,
+                encode_kwargs={"normalize_embeddings": True},
+                model_kwargs=model_kwargs,
+            )
+        else:
+            raise ValueError("Invalid embeddings type specified.")
+
     def initialize(
         self,
         embeddings_config: dict[str, str],
         reranking_model_name: str,
         use_cuda: bool = False,
         fast_mode: bool = False,
     ) -> None:
+        # Create shared model instances once
+        embedding_model = self._create_embedding_model(embeddings_config, use_cuda)
+        logging.info("Shared embedding model created.")
+
+        reranker_model = HuggingFaceCrossEncoder(model_name=reranking_model_name)
+        logging.info("Shared reranker model created.")
+
         markdown_docs_map = {
             "general": [
                 "./data/markdown/OR_docs",
@@ -100,6 +141,8 @@ def initialize(
             contextual_rerank=True,
             search_k=search_k,
             chunk_size=chunk_size,
+            embedding_model=embedding_model,
+            reranker_model=reranker_model,
         )
         general_retriever_chain.create_hybrid_retriever()
         RetrieverTools.general_retriever = general_retriever_chain.retriever
@@ -115,6 +158,8 @@ def initialize(
             contextual_rerank=True,
             search_k=search_k,
             chunk_size=chunk_size,
+            embedding_model=embedding_model,
+            reranker_model=reranker_model,
         )
         install_retriever_chain.create_hybrid_retriever()
         RetrieverTools.install_retriever = install_retriever_chain.retriever
@@ -131,6 +176,8 @@ def initialize(
             contextual_rerank=True,
             search_k=search_k,
             chunk_size=chunk_size,
+            embedding_model=embedding_model,
+            reranker_model=reranker_model,
         )
         commands_retriever_chain.create_hybrid_retriever()
         RetrieverTools.commands_retriever = commands_retriever_chain.retriever
@@ -146,6 +193,8 @@ def initialize(
             contextual_rerank=True,
             search_k=search_k,
             chunk_size=chunk_size,
+            embedding_model=embedding_model,
+            reranker_model=reranker_model,
         )
         yosys_rtdocs_retriever_chain.create_hybrid_retriever()
         RetrieverTools.yosys_rtdocs_retriever = yosys_rtdocs_retriever_chain.retriever
@@ -161,6 +210,8 @@ def initialize(
             contextual_rerank=True,
             search_k=search_k,
             chunk_size=chunk_size,
+            embedding_model=embedding_model,
+            reranker_model=reranker_model,
         )
         klayout_retriever_chain.create_hybrid_retriever()
         RetrieverTools.klayout_retriever = klayout_retriever_chain.retriever
@@ -176,6 +227,8 @@ def initialize(
             contextual_rerank=True,
             search_k=search_k,
             chunk_size=chunk_size,
+            embedding_model=embedding_model,
+            reranker_model=reranker_model,
         )
         errinfo_retriever_chain.create_hybrid_retriever()
         RetrieverTools.errinfo_retriever = errinfo_retriever_chain.retriever
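A hedged usage sketch of the new flow: the config values below are taken from the commit message, and the import path is assumed rather than verified. A single initialize() call now loads two models and shares them across all six chains it builds.

from src.agents.retriever_tools import RetrieverTools  # import path assumed

tools = RetrieverTools()
tools.initialize(
    embeddings_config={"type": "HF", "name": "thenlper/gte-large"},
    reranking_model_name="BAAI/bge-reranker-base",
    use_cuda=False,
    fast_mode=False,
)
# Only two model loads happen above; each chain still builds its own FAISS index.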

backend/src/chains/hybrid_retriever_chain.py

Lines changed: 24 additions & 3 deletions
@@ -5,8 +5,9 @@
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain_community.cross_encoders import HuggingFaceCrossEncoder
-from langchain_google_vertexai import ChatVertexAI
-from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
+from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
 from langchain_ollama import ChatOllama
 from langchain.retrievers.document_compressors.cross_encoder_rerank import (
     CrossEncoderReranker,
@@ -38,6 +39,14 @@ def __init__(
         weights: list[float] = [0.33, 0.33, 0.33],
         chunk_size: int = 500,
         contextual_rerank: bool = False,
+        embedding_model: Optional[
+            Union[
+                HuggingFaceEmbeddings,
+                GoogleGenerativeAIEmbeddings,
+                VertexAIEmbeddings,
+            ]
+        ] = None,
+        reranker_model: Optional[HuggingFaceCrossEncoder] = None,
     ):
         super().__init__(
             llm_model=llm_model,
@@ -48,6 +57,14 @@ def __init__(
 
         self.reranking_model_name: Optional[str] = reranking_model_name
         self.use_cuda: bool = use_cuda
+        self.embedding_model: Optional[
+            Union[
+                HuggingFaceEmbeddings,
+                GoogleGenerativeAIEmbeddings,
+                VertexAIEmbeddings,
+            ]
+        ] = embedding_model
+        self.reranker_model: Optional[HuggingFaceCrossEncoder] = reranker_model
 
         self.search_k: int = search_k
         self.weights: list[float] = weights
@@ -74,6 +91,7 @@ def create_hybrid_retriever(self) -> None:
             html_docs_path=self.html_docs_path,
             chunk_size=self.chunk_size,
             use_cuda=self.use_cuda,
+            embedding_model=self.embedding_model,
         )
         if self.vector_db is None:
             cur_path = os.path.abspath(__file__)
@@ -121,8 +139,11 @@ def create_hybrid_retriever(self) -> None:
             )
 
         if self.contextual_rerank:
+            reranker = self.reranker_model or HuggingFaceCrossEncoder(
+                model_name=self.reranking_model_name
+            )
             compressor = CrossEncoderReranker(
-                model=HuggingFaceCrossEncoder(model_name=self.reranking_model_name),
+                model=reranker,
                 top_n=self.search_k,
             )
             self.retriever = ContextualCompressionRetriever(

backend/src/chains/similarity_retriever_chain.py

Lines changed: 18 additions & 2 deletions
@@ -3,8 +3,9 @@
 
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain.docstore.document import Document
-from langchain_google_vertexai import ChatVertexAI
-from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
+from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
 from langchain_ollama import ChatOllama
 
 from ..vectorstores.faiss import FAISSVectorDatabase
@@ -28,6 +29,13 @@ def __init__(
         embeddings_config: Optional[dict[str, str]] = None,
         use_cuda: bool = False,
         chunk_size: int = 500,
+        embedding_model: Optional[
+            Union[
+                HuggingFaceEmbeddings,
+                GoogleGenerativeAIEmbeddings,
+                VertexAIEmbeddings,
+            ]
+        ] = None,
     ):
         super().__init__(
             llm_model=llm_model,
@@ -40,6 +48,13 @@ def __init__(
 
         self.embeddings_config: Optional[dict[str, str]] = embeddings_config
         self.use_cuda: bool = use_cuda
+        self.embedding_model: Optional[
+            Union[
+                HuggingFaceEmbeddings,
+                GoogleGenerativeAIEmbeddings,
+                VertexAIEmbeddings,
+            ]
+        ] = embedding_model
 
         self.markdown_docs_path: Optional[list[str]] = markdown_docs_path
         self.other_docs_path: Optional[list[str]] = other_docs_path
@@ -125,6 +140,7 @@ def create_vector_db(self) -> None:
                 embeddings_model_name=self.embeddings_config["name"],
                 embeddings_type=self.embeddings_config["type"],
                 use_cuda=self.use_cuda,
+                embedding_model=self.embedding_model,
             )
         else:
             raise ValueError("Embeddings model config not provided correctly.")

backend/src/vectorstores/faiss.py

Lines changed: 9 additions & 3 deletions
@@ -28,16 +28,21 @@ def __init__(
         distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
         debug: bool = False,
         use_cuda: bool = False,
+        embedding_model: Optional[
+            Union[
+                HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings
+            ]
+        ] = None,
     ):
         self.embeddings_model_name = embeddings_model_name
 
-        model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}
-
         self.embedding_model: Union[
             HuggingFaceEmbeddings, GoogleGenerativeAIEmbeddings, VertexAIEmbeddings
         ]
 
-        if embeddings_type == "GOOGLE_GENAI":
+        if embedding_model is not None:
+            self.embedding_model = embedding_model
+        elif embeddings_type == "GOOGLE_GENAI":
             self.embedding_model = GoogleGenerativeAIEmbeddings(
                 model=self.embeddings_model_name,
                 task_type="retrieval_document",
@@ -51,6 +56,7 @@ def __init__(
             logging.info("Using Google VertexAI embeddings...")
 
         elif embeddings_type == "HF":
+            model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}
             self.embedding_model = HuggingFaceEmbeddings(
                 model_name=self.embeddings_model_name,
                 multi_process=False,
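The same fallback idea, isolated as a standalone sketch (resolve_embedding_model is a hypothetical helper, not part of the codebase): a pre-built shared instance takes precedence, and the old per-instance construction remains as the default, so existing callers that never pass embedding_model keep working unchanged.

from typing import Optional
from langchain_huggingface import HuggingFaceEmbeddings

def resolve_embedding_model(
    shared: Optional[HuggingFaceEmbeddings],
    model_name: str,
    use_cuda: bool = False,
) -> HuggingFaceEmbeddings:
    # Prefer the instance injected by the caller; it was loaded once at startup.
    if shared is not None:
        return shared
    # Otherwise fall back to constructing a fresh model, as the old code did.
    model_kwargs = {"device": "cuda"} if use_cuda else {"device": "cpu"}
    return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)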
