-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
138 lines (110 loc) · 3.97 KB
/
Copy pathapi.py
File metadata and controls
138 lines (110 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import numpy as np
import os
from typing import List, Optional
app = FastAPI()

# Enable CORS for local development
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    # A wildcard origin must not be combined with credentialed requests:
    # the CORS spec forbids it and it would let any website call this API
    # with the user's cookies. This service defines no cookie- or
    # session-based auth, so credentials are simply disabled.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Location of the persisted model on disk (relative to the working directory).
MODEL_PATH = "cost_model.pkl"

# The loaded model; stays None until the startup hook finds MODEL_PATH.
model = None
# Reusing the extraction logic (in a real app, this should be in a shared module)
def extract_features(title, description, tags):
    """Convert a task's text fields into a flat numeric feature vector.

    Args:
        title: Task title; ``None`` is treated as an empty string.
        description: Task description; ``None`` is treated as an empty string.
        tags: Either a list of tag strings, a single comma-separated string,
            or ``None`` (treated as no tags).

    Returns:
        A list of 14 ints: ``[len(title), len(description), len(tags)]``
        followed by one 0/1 flag per known keyword, in a fixed order.
    """
    # Normalise tags to a list of lowercase strings.
    if tags is None:
        tags = []
    elif isinstance(tags, str):
        tags = [t.strip().lower() for t in tags.split(",")]
    else:
        tags = [t.lower() for t in tags]

    title = title or ""
    description = description or ""

    # Combine all text for keyword search.
    full_text = (title + " " + description).lower()

    common_tags = [
        "frontend", "backend", "ai", "blockchain",
        "security", "ui/ux", "devops", "marketing",
        "qa", "analytics", "mobile",
    ]

    features = [
        len(title),
        len(description),
        len(tags),
    ]

    # A keyword counts if it is an exact tag OR appears anywhere in the text.
    # NOTE(review): the text check is a plain substring match, so short
    # keywords such as "ai" also fire inside longer words (e.g. "email").
    # Left unchanged because this presumably has to match the features the
    # model was trained on -- confirm against the training script.
    features.extend(
        int(keyword in tags or keyword in full_text)
        for keyword in common_tags
    )
    return features
@app.on_event("startup")
def load_model():
    """Startup hook: load the persisted model into the module-level ``model``.

    If no file exists at ``MODEL_PATH`` a warning is printed and ``model``
    is left untouched, so the service still starts.
    """
    global model

    # Nothing on disk yet: warn and bail out early.
    if not os.path.exists(MODEL_PATH):
        print("⚠️ Model not found. Please run train_model.py first.")
        return

    # NOTE: the file is deserialised with joblib (pickle-based); only load
    # model files from a trusted source.
    model = joblib.load(MODEL_PATH)
    print("✅ Model loaded successfully")
class TaskInput(BaseModel):
    """Request body describing a task whose cost should be predicted."""

    # Task headline; its length and keywords feed the feature vector.
    title: str
    # Longer free-text body, searched for the same keywords as the title.
    description: str
    # Tag labels attached to the task.
    tags: List[str]
class UpdateInput(BaseModel):
    """Request body pairing a task with its observed cost for online learning."""

    # Task headline; its length and keywords feed the feature vector.
    title: str
    # Longer free-text body, searched for the same keywords as the title.
    description: str
    # Tag labels attached to the task.
    tags: List[str]
    # The real cost of the task, used as the training target.
    actual_cost: float
@app.get("/")
def read_root():
    """Liveness probe: return a fixed status payload."""
    status_payload = {"status": "ML Service Running"}
    return status_payload
@app.post("/predict")
def predict(task: TaskInput):
    """Estimate the cost of a task with the currently loaded model.

    Returns ``{"predicted_cost": <int>}``; the estimate is floored at 10 and
    truncated to an integer.

    Raises:
        HTTPException: 503 if no model has been loaded.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    feature_row = extract_features(task.title, task.description, task.tags)

    # Note: model is a Pipeline(StandardScaler, SGDRegressor). It is given a
    # one-row batch, and the single prediction is taken back out.
    raw_estimate = model.predict([feature_row])[0]

    # Clamp to a minimum of 10 (this also rules out negative predictions).
    floor = 10.0
    estimate = raw_estimate if raw_estimate > floor else floor

    return {"predicted_cost": int(estimate)}
@app.post("/update")
def update(task: UpdateInput):
    """Incrementally train the loaded model on one observed (task, cost) pair.

    The model is expected to be a pipeline with steps named
    ``'standardscaler'`` and ``'sgdregressor'``. The pipeline as a whole is
    not refit; instead each step is updated with ``partial_fit`` on the single
    new sample, and the updated model is then written back to ``MODEL_PATH``.

    Returns:
        ``{"status": "model updated", "new_cost_learned": <actual_cost>}``.

    Raises:
        HTTPException: 503 if no model is loaded; 422 if ``actual_cost`` is
            NaN, infinite or negative; 500 if updating or saving fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Validate before touching the model: a NaN/inf target would turn the
    # regressor's weights non-finite, and that state would then be persisted.
    # This check sits outside the try block so the 422 is not rewrapped as 500.
    if not np.isfinite(task.actual_cost) or task.actual_cost < 0:
        raise HTTPException(
            status_code=422,
            detail="actual_cost must be a finite, non-negative number",
        )

    features = extract_features(task.title, task.description, task.tags)
    X_new = np.array([features])
    y_new = np.array([task.actual_cost])

    try:
        scaler = model.named_steps['standardscaler']
        regressor = model.named_steps['sgdregressor']

        # Update the scaler's running statistics first, then train the
        # regressor on the sample scaled with those updated statistics.
        scaler.partial_fit(X_new)
        X_scaled = scaler.transform(X_new)
        regressor.partial_fit(X_scaled, y_new)

        # Save atomically: write to a sibling temp file and swap it in, so an
        # interrupted write cannot leave a truncated model file behind.
        tmp_path = MODEL_PATH + ".tmp"
        joblib.dump(model, tmp_path)
        os.replace(tmp_path, MODEL_PATH)
    except Exception as e:
        print(f"Error updating model: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"status": "model updated", "new_cost_learned": task.actual_cost}
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces; the port comes from $PORT, defaulting to 8000.
    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run("api:app", host="0.0.0.0", port=listen_port)