Skip to content

Commit 2ddd4c9

Browse files
perf: run sync modules in bounded thread pool executor
Move sync forward() and batch() calls off the event loop into a dedicated ThreadPoolExecutor with context variable propagation. Keeps the existing hasattr(instance, 'aforward') check for async detection. Configurable via --sync-workers flag or server.sync_worker_threads in dspy.config.yaml. Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
1 parent 741ea52 commit 2ddd4c9

File tree

7 files changed

+178
-3
lines changed

7 files changed

+178
-3
lines changed

src/dspy_cli/commands/serve.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,13 @@ def _exec_clean(target_python: Path, args: list[str]) -> NoReturn:
102102
default=False,
103103
help="Enable API authentication via DSPY_API_KEY (default: disabled)",
104104
)
105-
def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, system, mcp, auth):
105+
@click.option(
106+
"--sync-workers",
107+
default=None,
108+
type=click.IntRange(1, 200),
109+
help="Number of threads for sync module execution (default: min(32, cpu+4))",
110+
)
111+
def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, system, mcp, auth, sync_workers):
106112
"""Start an HTTP API server that exposes your DSPy programs.
107113
108114
This command:
@@ -127,6 +133,7 @@ def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, sy
127133
openapi_format=openapi_format,
128134
mcp=mcp,
129135
auth=auth,
136+
sync_workers=sync_workers,
130137
)
131138
return
132139

@@ -192,6 +199,8 @@ def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, sy
192199
args.append("--mcp")
193200
if auth:
194201
args.append("--auth")
202+
if sync_workers is not None:
203+
args.extend(["--sync-workers", str(sync_workers)])
195204

196205
_exec_clean(target_python, args)
197206
else:
@@ -205,4 +214,5 @@ def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, sy
205214
openapi_format=openapi_format,
206215
mcp=mcp,
207216
auth=auth,
217+
sync_workers=sync_workers,
208218
)

src/dspy_cli/server/app.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from dspy_cli.discovery import discover_modules
1717
from dspy_cli.discovery.gateway_finder import get_gateways_for_module, is_cron_gateway
1818
from dspy_cli.gateway import APIGateway, IdentityGateway
19+
from dspy_cli.server.executor import init_executor, shutdown_executor, DEFAULT_SYNC_WORKERS
1920
from dspy_cli.server.logging import setup_logging
2021
from dspy_cli.server.metrics import get_all_metrics, get_program_metrics_cached
2122
from dspy_cli.server.routes import create_program_routes
@@ -32,6 +33,7 @@ def create_app(
3233
logs_dir: Path,
3334
enable_ui: bool = True,
3435
enable_auth: bool = False,
36+
sync_workers: int | None = None,
3537
) -> FastAPI:
3638
"""Create and configure the FastAPI application.
3739
@@ -42,13 +44,18 @@ def create_app(
4244
logs_dir: Directory for log files
4345
enable_ui: Whether to enable the web UI (always True, kept for compatibility)
4446
enable_auth: Whether to enable API authentication via DSPY_API_KEY
47+
sync_workers: Number of threads for sync module execution (overrides config)
4548
4649
Returns:
4750
Configured FastAPI application
4851
"""
4952
# Setup logging
5053
setup_logging()
5154

55+
# Initialize bounded executor for sync module execution
56+
worker_count = sync_workers or config.get("server", {}).get("sync_worker_threads") or DEFAULT_SYNC_WORKERS
57+
init_executor(max_workers=worker_count)
58+
5259
# Create FastAPI app
5360
app = FastAPI(
5461
title="DSPy API",
@@ -349,6 +356,8 @@ async def lifespan(app: FastAPI):
349356
except Exception as e:
350357
logger.warning(f"Gateway shutdown error: {e}")
351358

359+
shutdown_executor()
360+
352361

353362
def _create_lm_instance(model_config: Dict) -> dspy.LM:
354363
"""Create a DSPy LM instance from configuration.

src/dspy_cli/server/execution.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import dspy
1111

1212
from dspy_cli.discovery import DiscoveredModule
13+
from dspy_cli.server.executor import run_sync_in_executor
1314
from dspy_cli.server.logging import log_inference
1415

1516
logger = logging.getLogger(__name__)
@@ -280,7 +281,7 @@ async def execute_pipeline(
280281
if hasattr(instance, 'aforward'):
281282
result = await instance.acall(**inputs)
282283
else:
283-
result = instance(**inputs)
284+
result = await run_sync_in_executor(instance, **inputs)
284285

285286
output = _normalize_output(result, module)
286287
duration_ms = (time.time() - start_time) * 1000
@@ -393,7 +394,9 @@ async def execute_pipeline_batch(
393394
if max_errors is not None:
394395
batch_kwargs["max_errors"] = max_errors
395396

396-
batch_result = instance.batch(examples, **batch_kwargs)
397+
batch_result = await run_sync_in_executor(
398+
instance.batch, examples, **batch_kwargs
399+
)
397400

398401
if isinstance(batch_result, tuple) and len(batch_result) == 3:
399402
successful, failed_examples, exceptions = batch_result

src/dspy_cli/server/executor.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Bounded thread pool executor for sync DSPy module execution.
2+
3+
Sync forward() calls are dispatched here so they don't block the async
4+
event loop. Context variables (including dspy.context overrides) are
5+
propagated into the worker thread automatically.
6+
"""
7+
8+
import asyncio
9+
import contextvars
10+
import functools
11+
import logging
12+
import os
13+
from concurrent.futures import ThreadPoolExecutor
14+
from typing import Any, Callable, Optional
15+
16+
logger = logging.getLogger(__name__)
17+
18+
_executor: Optional[ThreadPoolExecutor] = None
19+
20+
DEFAULT_SYNC_WORKERS = min(32, (os.cpu_count() or 1) + 4)
21+
22+
23+
def init_executor(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
24+
"""Create the process-wide bounded executor."""
25+
global _executor
26+
if _executor is not None:
27+
_executor.shutdown(wait=False)
28+
29+
workers = max_workers or DEFAULT_SYNC_WORKERS
30+
_executor = ThreadPoolExecutor(max_workers=workers, thread_name_prefix="dspy-sync")
31+
logger.info(f"Initialized sync executor with {workers} worker threads")
32+
return _executor
33+
34+
35+
def shutdown_executor() -> None:
36+
"""Shut down the executor, waiting for pending work."""
37+
global _executor
38+
if _executor is not None:
39+
_executor.shutdown(wait=True)
40+
_executor = None
41+
42+
43+
async def run_sync_in_executor(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
44+
"""Run a sync callable in the bounded executor with context propagation.
45+
46+
Falls back to the default executor if init_executor() hasn't been called.
47+
"""
48+
loop = asyncio.get_running_loop()
49+
ctx = contextvars.copy_context()
50+
func_call = functools.partial(ctx.run, fn, *args, **kwargs)
51+
return await loop.run_in_executor(_executor, func_call)

src/dspy_cli/server/runner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ENV_ENABLE_MCP = "DSPY_CLI_ENABLE_MCP"
2121
ENV_LOGS_DIR = "DSPY_CLI_LOGS_DIR"
2222
ENV_AUTH_ENABLED = "DSPY_CLI_AUTH_ENABLED"
23+
ENV_SYNC_WORKERS = "DSPY_CLI_SYNC_WORKERS"
2324

2425

2526
def _maybe_mount_mcp(app, enable: bool, *, path: str = MCP_DEFAULT_PATH, notify=None) -> bool:
@@ -86,6 +87,8 @@ def create_app_instance():
8687
logs_dir = os.environ.get(ENV_LOGS_DIR, "./logs")
8788
enable_mcp = os.environ.get(ENV_ENABLE_MCP, "false").lower() == "true"
8889
enable_auth = os.environ.get(ENV_AUTH_ENABLED, "false").lower() == "true"
90+
sync_workers_str = os.environ.get(ENV_SYNC_WORKERS)
91+
sync_workers = int(sync_workers_str) if sync_workers_str else None
8992

9093
# Validate project structure
9194
if not validate_project_structure():
@@ -118,6 +121,7 @@ def create_app_instance():
118121
logs_dir=logs_path,
119122
enable_ui=True,
120123
enable_auth=enable_auth,
124+
sync_workers=sync_workers,
121125
)
122126

123127
# Mount MCP if enabled
@@ -135,6 +139,7 @@ def main(
135139
openapi_format: str = "json",
136140
mcp: bool = False,
137141
auth: bool = False,
142+
sync_workers: int | None = None,
138143
):
139144
"""Main server execution logic.
140145
@@ -192,6 +197,7 @@ def main(
192197
logs_dir=logs_path,
193198
enable_ui=True,
194199
enable_auth=auth,
200+
sync_workers=sync_workers,
195201
)
196202

197203
# Mount MCP if enabled
@@ -276,6 +282,8 @@ def notify_cli(msg: str, level: str = "info"):
276282
os.environ[ENV_LOGS_DIR] = str(logs_path)
277283
os.environ[ENV_ENABLE_MCP] = str(mcp).lower()
278284
os.environ[ENV_AUTH_ENABLED] = str(auth).lower()
285+
if sync_workers is not None:
286+
os.environ[ENV_SYNC_WORKERS] = str(sync_workers)
279287

280288
# Get project root and src directory for watching
281289
project_root = Path.cwd()
@@ -319,6 +327,7 @@ def notify_cli(msg: str, level: str = "info"):
319327
parser.add_argument("--openapi-format", choices=["json", "yaml"], default="json")
320328
parser.add_argument("--mcp", action="store_true", help="Enable MCP server at /mcp")
321329
parser.add_argument("--auth", action="store_true", help="Enable API authentication")
330+
parser.add_argument("--sync-workers", type=int, default=None, help="Number of sync worker threads")
322331
args = parser.parse_args()
323332

324333
main(
@@ -330,4 +339,5 @@ def notify_cli(msg: str, level: str = "info"):
330339
openapi_format=args.openapi_format,
331340
mcp=args.mcp,
332341
auth=args.auth,
342+
sync_workers=args.sync_workers,
333343
)

tests/test_commands_smoke.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def fake_runner_main(**kwargs):
123123
"openapi_format": "yaml",
124124
"mcp": False,
125125
"auth": False,
126+
"sync_workers": None,
126127
}
127128

128129

tests/test_executor.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Tests for the bounded executor and context propagation."""
2+
3+
import asyncio
4+
import contextvars
5+
6+
import pytest
7+
8+
from dspy_cli.server.executor import (
9+
init_executor,
10+
run_sync_in_executor,
11+
shutdown_executor,
12+
)
13+
14+
15+
@pytest.fixture(autouse=True)
16+
def _clean_executor():
17+
"""Ensure each test gets a fresh executor."""
18+
yield
19+
shutdown_executor()
20+
21+
22+
class TestContextPropagation:
23+
24+
def test_contextvar_propagates_to_executor_thread(self):
25+
cv = contextvars.ContextVar("test_cv", default="UNSET")
26+
init_executor(max_workers=2)
27+
28+
def read_cv():
29+
return cv.get()
30+
31+
async def run():
32+
cv.set("per-request-value")
33+
return await run_sync_in_executor(read_cv)
34+
35+
result = asyncio.get_event_loop().run_until_complete(run())
36+
assert result == "per-request-value"
37+
38+
def test_concurrent_requests_see_own_context(self):
39+
cv = contextvars.ContextVar("test_cv", default="UNSET")
40+
init_executor(max_workers=4)
41+
42+
results = {}
43+
44+
def read_cv():
45+
import time
46+
time.sleep(0.05)
47+
return cv.get()
48+
49+
async def make_request(name: str, value: str):
50+
cv.set(value)
51+
results[name] = await run_sync_in_executor(read_cv)
52+
53+
async def run():
54+
await asyncio.gather(
55+
make_request("a", "alpha"),
56+
make_request("b", "beta"),
57+
make_request("c", "gamma"),
58+
)
59+
60+
asyncio.get_event_loop().run_until_complete(run())
61+
assert results == {"a": "alpha", "b": "beta", "c": "gamma"}
62+
63+
def test_dspy_context_lm_propagates(self):
64+
import dspy
65+
66+
init_executor(max_workers=2)
67+
68+
def read_lm():
69+
return dspy.settings.lm
70+
71+
async def run():
72+
sentinel = object()
73+
with dspy.context(lm=sentinel):
74+
result = await run_sync_in_executor(read_lm)
75+
return result, sentinel
76+
77+
result, sentinel = asyncio.get_event_loop().run_until_complete(run())
78+
assert result is sentinel
79+
80+
def test_fallback_without_init(self):
81+
cv = contextvars.ContextVar("test_cv", default="UNSET")
82+
83+
def read_cv():
84+
return cv.get()
85+
86+
async def run():
87+
cv.set("fallback-value")
88+
return await run_sync_in_executor(read_cv)
89+
90+
result = asyncio.get_event_loop().run_until_complete(run())
91+
assert result == "fallback-value"

0 commit comments

Comments
 (0)