-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsql_agent.py
More file actions
134 lines (109 loc) · 4.77 KB
/
sql_agent.py
File metadata and controls
134 lines (109 loc) · 4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import sqlite3
import pandas as pd
import json
import os
from openai import OpenAI
from dotenv import load_dotenv # <--- Add this
# load_dotenv() runs before the API key is read from the environment below
# (presumably it populates the environment from a local .env file -- see python-dotenv).
load_dotenv()

# 1. SETUP: open the SQLite database and build the OpenAI client
DB_NAME = "ohm_sweet_ohm.db"
# check_same_thread=False is needed for some web frameworks/notebooks
conn = sqlite3.connect(database=DB_NAME, check_same_thread=False)
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# 2. CONTEXT: The Schema (Sanitized for the LLM)
# Plain-text description of the tables, their columns and the value
# conventions the model should follow when writing SQL. The text is
# interpolated verbatim into the `query` parameter description of the tool
# definition, so any edit here changes what the model is told.
database_schema = """
Tables available:
1. products (product_id, name, description, price, category, in_stock)
- Note: Valid Categories are 'Audio', 'Wearables', 'Accessories', 'Office', 'Gaming', 'TVs'.
- Note: When users ask for specific features (e.g. "wireless", "noise-cancelling"), check the 'description' column using LIKE.
- Note: 'in_stock' is a boolean flag (0=discontinued, 1=active).
2. stores (store_id, name, address, phone)
- Note: Store names are 'Ohm Sweet Ohm Downtown', 'Ohm Sweet Ohm Union Square', 'Ohm Sweet Ohm Mission District', 'Ohm Sweet Ohm Marina', 'Ohm Sweet Ohm SoMa'.
3. store_inventory (store_id, product_id, stock_level)
- Note: This table links Stores to Products. 'stock_level' is the actual quantity available on the shelf.
4. orders (order_id, customer_name, customer_email, status, days_since_order, current_location)
- Note: 'status' values: 'pending', 'processing', 'shipped', 'in_transit', 'delivered'.
- Note: There is NO date column. Use 'days_since_order' (integer) for time calculations.
5. order_items (order_id, product_id, quantity, unit_price)
6. promotions (promotion_id, description, discount_percent, discount_amount, category, product_ids, active)
- Note: A promotion will have EITHER a 'discount_percent' OR a 'discount_amount'. Check both columns.
"""
# 3. TOOL: The Python Function
def run_sql_query(query):
    """Execute a read-only SQL query and return the result as text.

    Args:
        query: SQL text, normally produced by the model's tool call.

    Returns:
        str: one of
          * a markdown table of the result rows,
          * "Query ran successfully but returned no results." for an empty
            result,
          * "Error: Read-only access." when the statement is not a plain
            SELECT/WITH query or contains a data- or schema-modifying keyword,
          * "SQL Error: <message>" when anything raises during execution.

    The function never raises: every failure is returned as a string so it
    can be passed back to the model as the tool output.
    """
    import re  # local import so the block does not depend on the file header

    try:
        # Safety: allow-list the statement type rather than deny-listing a few
        # substrings. The previous check let UPDATE / ALTER / CREATE /
        # REPLACE INTO / ATTACH / PRAGMA through to a read-write connection.
        if not re.match(r"\s*(SELECT|WITH)\b", query, re.IGNORECASE):
            return "Error: Read-only access."

        # A CTE ("WITH ...") may legally be followed by a write statement, and
        # a second statement may follow a ';', so also reject write keywords
        # anywhere in the text. Whole-word matching avoids tripping on values
        # such as 'undeleted'. REPLACE is only blocked as "REPLACE INTO" so
        # the replace() SQL function stays usable.
        # NOTE: a keyword inside a string literal (e.g. LIKE '%drop%') is still
        # rejected; this errs on the side of refusing.
        forbidden = (
            r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|ATTACH|DETACH|"
            r"PRAGMA|VACUUM|REINDEX)\b|\bREPLACE\s+INTO\b"
        )
        if re.search(forbidden, query, re.IGNORECASE):
            return "Error: Read-only access."

        print(f"\nrunning SQL: {query}")  # Debug print to see it working
        df = pd.read_sql_query(query, conn)
        if df.empty:
            return "Query ran successfully but returned no results."
        return df.to_markdown(index=False)
    except Exception as e:
        # Deliberately broad: this is the tool boundary, and the error text is
        # returned so the caller can hand it back to the model.
        return f"SQL Error: {str(e)}"
# 4. TOOL CONFIG: The JSON for OpenAI
# A single function tool, `run_sql_query`, taking one required string
# argument `query`. The schema text is embedded in the argument description.
tools = [
    {
        "type": "function",
        "function": {
            "name": "run_sql_query",
            "description": "Execute a SQL query to retrieve data.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": f"The SQL query to run. Schema:\n{database_schema}",
                    },
                },
                "required": ["query"],
            },
        },
    },
]
# 5. THE AGENT LOOP (The Logic)
def chat_with_sql_agent(user_question, max_tool_rounds=3):
    """Answer a question, letting the model query the database through tools.

    The model is called with the `run_sql_query` tool available. Whenever it
    requests tool calls, each one is executed, its output is appended to the
    conversation as a "tool" message, and the model is called again so it can
    turn the data into an answer (or issue a corrected query).

    Args:
        user_question: The question to answer, as plain text.
        max_tool_rounds: Maximum number of model turns that may request tools.
            Once exhausted, one last call is made with tool use disabled so
            the function always ends with a text reply.

    Returns:
        The text content of the model's final message.
    """
    print(f"\n👤 User: {user_question}")

    # A. Initial Message
    messages = [
        {"role": "system", "content": "You are a data assistant. You must check the database to answer questions."},
        {"role": "user", "content": user_question},
    ]

    for _ in range(max_tool_rounds):
        # B. Ask the LLM what to do next
        response = client.chat.completions.create(
            model="gpt-4o",  # or gpt-3.5-turbo
            messages=messages,
            tools=tools,
            tool_choice="auto",
        )
        msg = response.choices[0].message

        # C. No tool requested: this is the final answer
        if not msg.tool_calls:
            return msg.content

        # 1. Add the assistant's tool request to history
        messages.append(msg)

        # 2. Run every requested tool. Each tool_call needs a matching "tool"
        #    message before the model can be called again.
        for tool_call in msg.tool_calls:
            if tool_call.function.name == "run_sql_query":
                try:
                    # Extract SQL from JSON
                    sql_query = json.loads(tool_call.function.arguments)["query"]
                except (json.JSONDecodeError, KeyError, TypeError) as exc:
                    # Report malformed arguments to the model instead of crashing
                    result_data = f"Error: invalid tool arguments ({exc})."
                else:
                    result_data = run_sql_query(sql_query)
            else:
                result_data = f"Error: unknown tool '{tool_call.function.name}'."

            print(f"🔍 DEBUG - Tool Output: {result_data}")

            # 3. Add the "Result" to history
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": result_data,
            })

    # D. Tool budget used up: force a plain-text answer from what was gathered
    final_response = client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        tools=tools,
        tool_choice="none",
    )
    return final_response.choices[0].message.content
# 6. TEST IT
if __name__ == "__main__":
    # Test 1: Simple lookup
    answer = chat_with_sql_agent("How many headphones are in the Downtown store?")
    print("🤖 Agent:", answer)
    # Test 2: Complex Join (disabled)
    # print("🤖 Agent:", chat_with_sql_agent("Which store has the most stock of cables?"))