# search_agent.py
import json
import os
import dotenv
from openai import AzureOpenAI, OpenAI, AsyncAzureOpenAI, AsyncOpenAI
import litellm
litellm.num_retries = 3
from typing import List, Dict
from utils import retry
dotenv.load_dotenv()
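
# Environment variables read below (loaded from .env by load_dotenv()):
#   azure      -> AZURE_OPENAI_API_KEY, AZURE_API_VERSION, AZURE_OPENAI_ENDPOINT
#   openai     -> PERSONAL_OAI_API_KEY
#   gemini     -> GEMINI_API_KEY
#   local_vllm -> no API key; requests go to the hard-coded vLLM endpoints set in __init__.
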

class SearchAgent:
    """
    A basic chat agent supporting multi-round conversations with an LLM.

    Features:
    1. Set a system message.
    2. Load conversation history from a JSON file.
    3. Save conversation history to a JSON file.
    4. Receive a user query, prompt the LLM with history + query,
       and return the assistant response.
    """
    def __init__(self, model: str = "gemini/gemini-2.5-pro", temperature: float = 1.0, provider: str = "azure", budget_tokens: int = 1024):
        if provider == "azure":
            self.openai = AzureOpenAI(
                api_key=os.getenv("AZURE_OPENAI_API_KEY"),
                api_version=os.getenv("AZURE_API_VERSION"),
                azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT")
            )
            self.async_openai = AsyncAzureOpenAI(
                api_key=os.getenv("AZURE_OPENAI_API_KEY"),
                api_version=os.getenv("AZURE_API_VERSION"),
                azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT")
            )
            # Azure deployments are addressed without the "azure/" prefix.
            model = model.split('azure/')[-1]
        elif provider == "openai":
            self.openai = OpenAI(
                api_key=os.getenv("PERSONAL_OAI_API_KEY")
            )
            self.async_openai = AsyncOpenAI(
                api_key=os.getenv("PERSONAL_OAI_API_KEY")
            )
            self.model = "gpt-4.1"
        elif provider == "gemini":
            self.openai = OpenAI(
                api_key=os.getenv("GEMINI_API_KEY"),
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
            )
            self.async_openai = AsyncOpenAI(
                api_key=os.getenv("GEMINI_API_KEY"),
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
            )
        elif provider == "local_vllm":
            if "Qwen3-8B" in model:
                self.openai = OpenAI(
                    api_key="EMPTY",
                    base_url="http://jagupard36:8001/v1"
                )
                self.async_openai = AsyncOpenAI(
                    api_key="EMPTY",
                    base_url="http://jagupard36:8001/v1"
                )
            elif "Qwen3-32B" in model:
                self.openai = OpenAI(
                    api_key="EMPTY",
                    base_url="http://sphinx10:8001/v1"
                )
                self.async_openai = AsyncOpenAI(
                    api_key="EMPTY",
                    base_url="http://sphinx10:8001/v1"
                )
        else:
            raise ValueError(f"Invalid provider: {provider}")
        self.model = model
        self.temperature = temperature
        self.messages: List[Dict[str, str]] = []
        self.budget_tokens = budget_tokens
        self.reasoning_effort = None
        # A "-high"/"-medium"/"-low" suffix on the model name selects the reasoning
        # effort passed to the API; the suffix itself is stripped from the model name.
        if self.model.endswith("-high") or self.model.endswith("-medium") or self.model.endswith("-low"):
            self.reasoning_effort = self.model.split("-")[-1]
            self.model = self.model.rsplit("-", 1)[0]
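
        # Model-name conventions handled above (illustrative, not exhaustive):
        #   azure      -> "azure/<deployment-name>"
        #   gemini     -> "gemini/gemini-2.5-pro" (optionally suffixed with -high/-medium/-low)
        #   local_vllm -> any name containing "Qwen3-8B" or "Qwen3-32B"
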
    def set_system_message(self, system_message: str) -> None:
        """Set or replace the system prompt at the start of the conversation."""
        # Remove any existing system messages
        self.messages = [m for m in self.messages if m.get("role") != "system"]
        # Insert new system message at the beginning
        self.messages.insert(0, {"role": "system", "content": system_message})

    def load_history(self, filepath: str) -> None:
        """Load conversation history from a JSON file (sync version)."""
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            if isinstance(data, list):
                self.messages = data
            else:
                raise ValueError("History file must contain a list of messages")
        except (IOError, json.JSONDecodeError) as e:
            print(f"Error loading history: {e}")

    async def load_history_async(self, filepath: str) -> None:
        """Load conversation history from a JSON file (async version)."""
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
            if isinstance(data, list):
                self.messages = data
            else:
                raise ValueError("History file must contain a list of messages")
        except (IOError, json.JSONDecodeError) as e:
            print(f"Error loading history: {e}")

    def load_history_from_list(self, messages: List[Dict[str, str]]) -> None:
        """Load conversation history from a list of messages."""
        self.messages = messages

    def save_history(self, filepath: str) -> None:
        """Save the current conversation history to a JSON file (sync version)."""
        try:
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(self.messages, f, ensure_ascii=False, indent=2)
        except IOError as e:
            print(f"Error saving history: {e}")

    async def save_history_async(self, filepath: str) -> None:
        """Save the current conversation history to a JSON file (async version)."""
        try:
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(self.messages, f, ensure_ascii=False, indent=2)
        except IOError as e:
            print(f"Error saving history: {e}")
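
    # The history file is simply the messages list serialized as JSON, e.g.
    # (illustrative content):
    #   [
    #     {"role": "system", "content": "You are a helpful assistant."},
    #     {"role": "user", "content": "Hello!"},
    #     {"role": "assistant", "content": "Hi there."}
    #   ]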

    def chat(self, user_query: str) -> str:
        """
        Add a user query to the conversation, call the LLM, store and return its response (sync version).
        """
        # Append user's message
        self.messages.append({"role": "user", "content": user_query})
        # Send to OpenAI ChatCompletion
        reasoning_effort = self.reasoning_effort
        print(f"Using model: {self.model} with reasoning_effort: {reasoning_effort}")
        if reasoning_effort:
            response = self.openai.chat.completions.create(
                model=self.model,
                messages=self.messages,
                temperature=self.temperature,
                reasoning_effort=reasoning_effort
            )
        else:
            response = self.openai.chat.completions.create(
                model=self.model,
                messages=self.messages,
                temperature=self.temperature
            )
        # if "gemini" in self.model:
        #     response = litellm.completion(
        #         model=self.model,
        #         messages=self.messages,
        #         thinking={"type": "enabled", "budget_tokens": self.budget_tokens}
        #     )
        # elif "gpt" in self.model:
        #     response = litellm.completion(
        #         model=self.model,
        #         messages=self.messages
        #     )
        # else:
        #     raise ValueError(f"Invalid model: {self.model}")
        assistant_message = response.choices[0].message.content
        # Append assistant's reply to history
        self.messages.append({"role": "assistant", "content": assistant_message})
        # Print usage
        print(f"Usage: {response.usage}")
        return assistant_message

    async def chat_async(self, user_query: str) -> str:
        """
        Add a user query to the conversation, call the LLM, store and return its response (async version).
        """
        # Append user's message
        self.messages.append({"role": "user", "content": user_query})
        # Send to OpenAI ChatCompletion
        reasoning_effort = self.reasoning_effort
        if reasoning_effort:
            response = await self.async_openai.chat.completions.create(
                model=self.model,
                messages=self.messages,
                temperature=self.temperature,
                reasoning_effort=reasoning_effort
            )
        else:
            response = await self.async_openai.chat.completions.create(
                model=self.model,
                messages=self.messages,
                temperature=self.temperature
            )
        # if "gemini" in self.model:
        #     response = await litellm.acompletion(
        #         model=self.model,
        #         messages=self.messages,
        #         thinking={"type": "enabled", "budget_tokens": self.budget_tokens}
        #     )
        # elif "gpt" in self.model:
        #     response = await litellm.acompletion(
        #         model=self.model,
        #         messages=self.messages
        #     )
        # else:
        #     raise ValueError(f"Invalid model: {self.model}")
        assistant_message = response.choices[0].message.content
        # Append assistant's reply to history
        self.messages.append({"role": "assistant", "content": assistant_message})
        # Print usage
        # print("================================")
        # print(f"Message: {user_query}")
        # print(f"Response: {assistant_message}")
        # print(f"Usage: {response.usage}")
        # print("================================")
        return assistant_message
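

# A minimal async usage sketch (not part of the original module). It assumes the
# chosen provider's credentials are already set in the environment and that the
# model name below is valid for that provider; run it with:
#   import asyncio; asyncio.run(_demo_chat_async())
async def _demo_chat_async() -> None:
    agent = SearchAgent(model="gemini/gemini-2.5-pro", provider="gemini", temperature=0.7)
    agent.set_system_message("You are a concise assistant.")
    reply = await agent.chat_async("In one sentence, what does a search agent do?")
    print(reply)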


if __name__ == "__main__":
    # Example usage
    agent = SearchAgent(model="azure/gpt-5-mini-250807-65987", provider="azure")
    agent.set_system_message("""
    You are a helpful assistant. Your goal is to guess a number between 1 and 10.
    If the user says "higher", you should guess a higher number.
    If the user says "lower", you should guess a lower number.
    If the user says "correct", you should stop guessing.
    Return your guess in the following format:
    <guess><number></guess>
    Your initial guess is 5.
    """)
    # Load previous history if it exists
    # agent.load_history("history.json")
    while True:
        query = input("User: ")
        if query.lower() in {"exit", "quit"}:
            print("Exiting chat.")
            break
        reply = agent.chat(query)
        print(f"Assistant: {reply}\n")
    # Save conversation history on exit
    agent.save_history("history.json")