Skip to content

Commit 19133fb

Browse files
committed
fix issue with aps endpoint
1 parent 82849c6 commit 19133fb

File tree

4 files changed

+50
-26
lines changed

4 files changed

+50
-26
lines changed

ap_monitor/app/db.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ def get_apclient_db_dep():
110110
finally:
111111
db.close()
112112

113+
def get_wireless_db_dep():
114+
"""FastAPI dependency for getting wireless_count DB session (generator, not context manager)."""
115+
db = WirelessSessionLocal()
116+
try:
117+
yield db
118+
finally:
119+
db.close()
120+
113121
def init_db():
114122
"""Initialize databases by creating tables."""
115123
try:
@@ -144,5 +152,6 @@ def init_db():
144152
'get_wireless_db_session',
145153
'get_apclient_db_session',
146154
'get_apclient_db_dep',
155+
'get_wireless_db_dep',
147156
'init_db'
148157
]

ap_monitor/app/main.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
init_db,
2222
WirelessBase,
2323
APClientBase,
24-
get_apclient_db_dep
24+
get_apclient_db_dep,
25+
get_wireless_db_dep
2526
)
2627
from ap_monitor.app.models import (
2728
Campus, Building, ClientCount,
@@ -714,13 +715,12 @@ def insert_apclientcount_data(device_info_list, timestamp, session=None):
714715
raise
715716

716717
@app.get("/aps", response_model=List[dict], tags=["Access Points"])
717-
def get_aps(db: Session = Depends(get_wireless_db)):
718+
def get_aps(db: Session = Depends(get_wireless_db_dep)):
718719
"""Get all access points from the database."""
719720
try:
720721
logger.info("Fetching AP data from the database")
721722
aps = db.query(AccessPoint).all()
722723
logger.info(f"Retrieved {len(aps)} AP records")
723-
724724
return [{
725725
"apid": ap.apid,
726726
"apname": ap.apname,
@@ -774,7 +774,7 @@ def get_client_counts(
774774
raise HTTPException(status_code=500, detail=str(e))
775775

776776
@app.get("/buildings", response_model=List[dict], tags=["Buildings"])
777-
def get_buildings(db: Session = Depends(get_wireless_db)):
777+
def get_buildings(db: Session = Depends(get_wireless_db_dep)):
778778
"""Get list of buildings with their details."""
779779
try:
780780
logger.info("Fetching list of buildings")
@@ -792,7 +792,7 @@ def get_buildings(db: Session = Depends(get_wireless_db)):
792792
raise HTTPException(status_code=500, detail="Internal server error")
793793

794794
@app.get("/floors/{building_id}", response_model=List[dict], tags=["Floors"])
795-
def get_floors(building_id: int, db: Session = Depends(get_wireless_db)):
795+
def get_floors(building_id: int, db: Session = Depends(get_wireless_db_dep)):
796796
"""Get floors for a specific building."""
797797
try:
798798
floors = db.query(Floor).filter_by(buildingid=building_id).all()
@@ -810,7 +810,7 @@ def get_floors(building_id: int, db: Session = Depends(get_wireless_db)):
810810
raise HTTPException(status_code=500, detail="Internal server error")
811811

812812
@app.get("/rooms/{floor_id}", response_model=List[dict], tags=["Rooms"])
813-
def get_rooms(floor_id: int, db: Session = Depends(get_wireless_db)):
813+
def get_rooms(floor_id: int, db: Session = Depends(get_wireless_db_dep)):
814814
"""Get rooms for a specific floor."""
815815
try:
816816
rooms = db.query(Room).filter_by(floorid=floor_id).all()
@@ -827,7 +827,7 @@ def get_rooms(floor_id: int, db: Session = Depends(get_wireless_db)):
827827
raise HTTPException(status_code=500, detail="Internal server error")
828828

829829
@app.get("/radio-types", response_model=List[dict], tags=["Radio Types"])
830-
def get_radio_types(db: Session = Depends(get_wireless_db)):
830+
def get_radio_types(db: Session = Depends(get_wireless_db_dep)):
831831
"""Get all radio types."""
832832
try:
833833
radio_types = db.query(RadioType).all()
@@ -843,7 +843,7 @@ def get_radio_types(db: Session = Depends(get_wireless_db)):
843843
raise HTTPException(status_code=500, detail="Internal server error")
844844

845845
@app.post("/wireless/campuses/", response_model=CampusResponse)
846-
def create_campus(campus: CampusCreate, db: Session = Depends(get_wireless_db)):
846+
def create_campus(campus: CampusCreate, db: Session = Depends(get_wireless_db_dep)):
847847
"""Create a new campus."""
848848
db_campus = Campus(campus_name=campus.campus_name)
849849
db.add(db_campus)
@@ -852,12 +852,12 @@ def create_campus(campus: CampusCreate, db: Session = Depends(get_wireless_db)):
852852
return db_campus
853853

854854
@app.get("/wireless/campuses/", response_model=List[CampusResponse])
855-
def get_campuses(db: Session = Depends(get_wireless_db)):
855+
def get_campuses(db: Session = Depends(get_wireless_db_dep)):
856856
"""Get all campuses."""
857857
return db.query(Campus).all()
858858

859859
@app.post("/wireless/buildings/", response_model=BuildingResponse)
860-
def create_building(building: BuildingCreate, db: Session = Depends(get_wireless_db)):
860+
def create_building(building: BuildingCreate, db: Session = Depends(get_wireless_db_dep)):
861861
"""Create a new building."""
862862
db_building = Building(**building.dict())
863863
db.add(db_building)
@@ -866,15 +866,15 @@ def create_building(building: BuildingCreate, db: Session = Depends(get_wireless
866866
return db_building
867867

868868
@app.get("/wireless/buildings/", response_model=List[BuildingResponse])
869-
def get_wireless_buildings(campus_id: Optional[int] = None, db: Session = Depends(get_wireless_db)):
869+
def get_wireless_buildings(campus_id: Optional[int] = None, db: Session = Depends(get_wireless_db_dep)):
870870
"""Get all buildings, optionally filtered by campus."""
871871
query = db.query(Building)
872872
if campus_id:
873873
query = query.filter(Building.campus_id == campus_id)
874874
return query.all()
875875

876876
@app.post("/wireless/client-counts/", response_model=ClientCountResponse)
877-
def create_client_count(count: ClientCountCreate, db: Session = Depends(get_wireless_db)):
877+
def create_client_count(count: ClientCountCreate, db: Session = Depends(get_wireless_db_dep)):
878878
"""Create a new client count."""
879879
db_count = ClientCount(**count.dict())
880880
db.add(db_count)
@@ -887,7 +887,7 @@ def get_wireless_client_counts(
887887
building_id: Optional[int] = None,
888888
start_time: Optional[datetime] = None,
889889
end_time: Optional[datetime] = None,
890-
db: Session = Depends(get_wireless_db)
890+
db: Session = Depends(get_wireless_db_dep)
891891
):
892892
"""Get client counts with optional filters."""
893893
query = db.query(ClientCount)

ap_monitor/tests/conftest.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
)
1919
from ap_monitor.app.db import (
2020
get_wireless_db,
21-
get_apclient_db
21+
get_apclient_db,
22+
get_wireless_db_dep,
23+
get_apclient_db_dep
2224
)
2325
from ap_monitor.app.main import app
2426

@@ -62,26 +64,19 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
6264
@pytest.fixture(autouse=True)
6365
def create_test_db():
6466
# Import wireless models before creating wireless tables
65-
from ap_monitor.app.models import Building, Campus, ClientCount, WirelessBase
67+
from ap_monitor.app.models import Building, Campus, ClientCount, WirelessBase, ApBuilding, Floor, Room, AccessPoint, ClientCountAP, RadioType, APClientBase
6668
WirelessBase.metadata.drop_all(bind=wireless_engine)
6769
WirelessBase.metadata.create_all(bind=wireless_engine)
68-
69-
# Import apclient models before creating apclient tables
70-
from ap_monitor.app.models import ApBuilding, Floor, Room, AccessPoint, ClientCountAP, RadioType, APClientBase
7170
APClientBase.metadata.drop_all(bind=apclient_engine)
7271
APClientBase.metadata.create_all(bind=apclient_engine)
73-
7472
# Verify tables are created correctly
7573
inspector = inspect(apclient_engine)
7674
tables = inspector.get_table_names()
7775
print(f"Tables in apclient_engine: {tables}")
78-
79-
# Verify table schemas
8076
for table_name in ['buildings', 'floors', 'rooms', 'accesspoints', 'clientcount', 'radiotypes']:
8177
assert table_name in tables, f"{table_name} table not created"
8278
columns = [col['name'] for col in inspector.get_columns(table_name)]
8379
print(f"Columns in {table_name}: {columns}")
84-
8580
# Add default radio types
8681
with APClientSessionLocal() as session:
8782
if not session.query(RadioType).first():
@@ -91,9 +86,7 @@ def create_test_db():
9186
RadioType(radioname="radio2", radioid=3)
9287
])
9388
session.commit()
94-
9589
yield
96-
9790
WirelessBase.metadata.drop_all(bind=wireless_engine)
9891
APClientBase.metadata.drop_all(bind=apclient_engine)
9992

@@ -131,8 +124,22 @@ def override_get_apclient_db():
131124
finally:
132125
pass
133126

127+
def override_get_wireless_db_dep():
128+
try:
129+
yield wireless_db
130+
finally:
131+
pass
132+
133+
def override_get_apclient_db_dep():
134+
try:
135+
yield apclient_db
136+
finally:
137+
pass
138+
134139
app.dependency_overrides[get_wireless_db] = override_get_wireless_db
135140
app.dependency_overrides[get_apclient_db] = override_get_apclient_db
141+
app.dependency_overrides[get_wireless_db_dep] = override_get_wireless_db_dep
142+
app.dependency_overrides[get_apclient_db_dep] = override_get_apclient_db_dep
136143

137144
# Add scheduler to app state
138145
app.state.scheduler = scheduler

ap_monitor/tests/test_main.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,17 @@ def override_get_db_with_mock_ap():
8383
def override():
8484
yield mock_session
8585

86+
from ap_monitor.app.db import get_wireless_db_dep, get_apclient_db_dep
8687
app.dependency_overrides[get_wireless_db] = override
8788
app.dependency_overrides[get_apclient_db] = override
89+
app.dependency_overrides[get_wireless_db_dep] = override
90+
app.dependency_overrides[get_apclient_db_dep] = override
8891
yield
8992
app.dependency_overrides.clear()
9093

9194
@pytest.fixture
9295
def override_get_db_with_mock_buildings():
9396
mock_building = MagicMock()
94-
# Patch to match the API's expected attribute names and values
9597
mock_building.building_id = 1
9698
mock_building.building_name = "BuildingA"
9799

@@ -104,8 +106,11 @@ def override_get_db_with_mock_buildings():
104106
def override():
105107
yield mock_session
106108

109+
from ap_monitor.app.db import get_wireless_db_dep, get_apclient_db_dep
107110
app.dependency_overrides[get_wireless_db] = override
108111
app.dependency_overrides[get_apclient_db] = override
112+
app.dependency_overrides[get_wireless_db_dep] = override
113+
app.dependency_overrides[get_apclient_db_dep] = override
109114
yield
110115
app.dependency_overrides.clear()
111116

@@ -353,10 +358,13 @@ def override_get_db_with_mock_aps():
353358
mock_session.query.return_value = mock_query
354359

355360
def override():
356-
return mock_session
361+
yield mock_session
357362

363+
from ap_monitor.app.db import get_wireless_db_dep, get_apclient_db_dep
358364
app.dependency_overrides[get_wireless_db] = override
359365
app.dependency_overrides[get_apclient_db] = override
366+
app.dependency_overrides[get_wireless_db_dep] = override
367+
app.dependency_overrides[get_apclient_db_dep] = override
360368
yield
361369
app.dependency_overrides.clear()
362370

0 commit comments

Comments
 (0)