Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 100 additions & 6 deletions src/sktime_mcp/runtime/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,94 @@ def predict(
except Exception as e:
return {"success": False, "error": str(e)}

def predict_interval(
self,
handle_id: str,
fh: Optional[Union[int, list[int]]] = None,
X: Optional[Any] = None,
coverage: Union[float, list[float]] = 0.9,
) -> dict[str, Any]:
"""Generate predictions with prediction intervals.

Args:
handle_id: Estimator handle
fh: Forecast horizon
X: Optional exogenous variables
coverage: Confidence level(s) for prediction intervals.
Can be a single float (e.g., 0.9 for 90% interval) or
a list of floats for multiple intervals.

Returns:
Dictionary with success status, predictions, and intervals
"""
try:
instance = self._handle_manager.get_instance(handle_id)
except KeyError:
return {"success": False, "error": f"Handle not found: {handle_id}"}

if not self._handle_manager.is_fitted(handle_id):
return {"success": False, "error": "Estimator not fitted"}

try:
if fh is None:
fh = list(range(1, 13))

# Check if estimator supports prediction intervals
if not hasattr(instance, "predict_interval"):
return {
"success": False,
"error": "Estimator does not support prediction intervals. "
"Use a probabilistic forecaster (check capability:pred_int tag).",
}

pred_intervals = instance.predict_interval(fh=fh, X=X, coverage=coverage)
Comment on lines +220 to +227
Copy link

Copilot AI Apr 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

predict_interval() checks support via hasattr(instance, "predict_interval"), but most sktime forecasters inherit this method even when they don't support intervals (capability tag false), so this check can be ineffective and lead to confusing runtime errors. Prefer checking the estimator tag (e.g., get_tag("capability:pred_int")) and/or catching NotImplementedError to return the intended "does not support" error.

Suggested change
if not hasattr(instance, "predict_interval"):
return {
"success": False,
"error": "Estimator does not support prediction intervals. "
"Use a probabilistic forecaster (check capability:pred_int tag).",
}
pred_intervals = instance.predict_interval(fh=fh, X=X, coverage=coverage)
supports_pred_int = True
if hasattr(instance, "get_tag"):
supports_pred_int = instance.get_tag("capability:pred_int", False)
if not supports_pred_int:
return {
"success": False,
"error": "Estimator does not support prediction intervals. "
"Use a probabilistic forecaster (check capability:pred_int tag).",
}
try:
pred_intervals = instance.predict_interval(fh=fh, X=X, coverage=coverage)
except NotImplementedError:
return {
"success": False,
"error": "Estimator does not support prediction intervals. "
"Use a probabilistic forecaster (check capability:pred_int tag).",
}

Copilot uses AI. Check for mistakes.

# Convert to JSON-serializable format
# The result is a DataFrame with MultiIndex columns: (variable, coverage, lower/upper)
pred_intervals_copy = pred_intervals.copy()
pred_intervals_copy.index = pred_intervals_copy.index.astype(str)

# Extract intervals into a structured format
intervals_dict = {}
for col in pred_intervals_copy.columns:
var_name = col[0] if isinstance(col, tuple) else "predictions"
cov_level = col[1] if isinstance(col, tuple) else coverage
bound_type = col[2] if isinstance(col, tuple) else "bound"

key = f"{var_name}_{cov_level}"
if key not in intervals_dict:
intervals_dict[key] = {"coverage": cov_level}
intervals_dict[key][bound_type] = pred_intervals_copy[col].to_dict()

return {
"success": True,
"predictions": pred_intervals_copy.to_dict(),
"intervals": intervals_dict,
"horizon": len(fh) if hasattr(fh, "__len__") else fh,
"coverage": coverage,
}
Comment on lines +234 to +252
Copy link

Copilot AI Apr 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

predict_interval() currently returns pred_intervals_copy.to_dict() under the predictions key, which is not point forecasts and also changes the output shape compared to predict(). This conflicts with fit_predict_tool's docs/contract that predictions are forecast values and intervals contains bounds. Consider returning point predictions from predict() plus a separate intervals structure, and align the interval key structure with the tool docs (e.g., avoid embedding coverage into the dict key unless documented).

Copilot uses AI. Check for mistakes.
except Exception as e:
return {"success": False, "error": str(e)}

def fit_predict(
self,
handle_id: str,
dataset: str,
horizon: int = 12,
data_handle: Optional[str] = None,
coverage: Optional[Union[float, list[float]]] = None,
) -> dict[str, Any]:
"""Convenience method: load data, fit, and predict."""
"""Convenience method: load data, fit, and predict.

Args:
handle_id: Estimator handle
dataset: Dataset name
horizon: Forecast horizon
data_handle: Optional data handle for custom data
coverage: Optional confidence level(s) for prediction intervals.
If provided, uses predict_interval() instead of predict().
Can be a single float (e.g., 0.9) or list of floats.
"""
if data_handle is not None:
# Use custom loaded data
if data_handle not in self._data_handles:
Expand All @@ -217,6 +297,9 @@ def fit_predict(
if not fit_result["success"]:
return fit_result

# Use predict_interval if coverage is provided, otherwise use predict
if coverage is not None:
return self.predict_interval(handle_id, fh=fh, X=X, coverage=coverage)
return self.predict(handle_id, fh=fh, X=X)
Comment on lines 187 to 303
Copy link

Copilot AI Apr 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new coverage/prediction-interval path (Executor.predict_interval + fit_predict(..., coverage=...)) isn't exercised by existing tests. Adding a test that instantiates a forecaster with capability:pred_int and asserts the returned JSON structure (point predictions + lower/upper bounds for the requested coverage) would help ensure this API remains stable.

Copilot uses AI. Check for mistakes.

async def fit_predict_async(
Expand All @@ -225,6 +308,7 @@ async def fit_predict_async(
dataset: str,
horizon: int = 12,
job_id: Optional[str] = None,
coverage: Optional[Union[float, list[float]]] = None,
) -> dict[str, Any]:
"""
Async version of fit_predict with job tracking.
Expand All @@ -237,6 +321,8 @@ async def fit_predict_async(
dataset: Dataset name
horizon: Forecast horizon
job_id: Optional job ID for tracking (created if not provided)
coverage: Optional confidence level(s) for prediction intervals.
If provided, uses predict_interval() instead of predict().

Returns:
Dictionary with success status and job_id
Expand Down Expand Up @@ -304,17 +390,25 @@ async def fit_predict_async(
return fit_result

# Step 3: Generate predictions
step_msg = f"Generating predictions (horizon={horizon})"
if coverage is not None:
step_msg += f" with coverage={coverage}"
self._job_manager.update_job(
job_id,
completed_steps=2,
current_step=f"Generating predictions (horizon={horizon})...",
current_step=step_msg,
)
await asyncio.sleep(0.01) # Yield control

# Run predict in executor
predict_result = await loop.run_in_executor(
None, lambda: self.predict(handle_id, fh=fh, X=X)
)
# Run predict or predict_interval in executor
if coverage is not None:
predict_result = await loop.run_in_executor(
None, lambda: self.predict_interval(handle_id, fh=fh, X=X, coverage=coverage)
)
else:
predict_result = await loop.run_in_executor(
None, lambda: self.predict(handle_id, fh=fh, X=X)
)

if not predict_result["success"]:
self._job_manager.update_job(
Expand Down
94 changes: 91 additions & 3 deletions src/sktime_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
release_data_handle_tool,
)
from sktime_mcp.tools.describe_estimator import describe_estimator_tool
from sktime_mcp.tools.evaluate import evaluate_estimator_tool
from sktime_mcp.tools.evaluate import compare_estimators_tool, evaluate_estimator_tool
from sktime_mcp.tools.fit_predict import (
fit_predict_async_tool,
fit_predict_tool,
Expand Down Expand Up @@ -244,7 +244,8 @@ async def list_tools() -> list[Tool]:
name="fit_predict",
description=(
"Fit an estimator on a dataset and generate predictions. "
"Accepts either a demo dataset name or a data_handle from load_data_source."
"Accepts either a demo dataset name or a data_handle from load_data_source. "
"Use the coverage parameter to get prediction intervals."
),
inputSchema={
"type": "object",
Expand All @@ -269,6 +270,16 @@ async def list_tools() -> list[Tool]:
"description": "Forecast horizon (default: 12)",
"default": 12,
},
"coverage": {
"type": ["number", "array"],
"description": (
"Confidence level(s) for prediction intervals. "
"Use a single number (e.g., 0.9 for 90% intervals) or "
"a list of numbers for multiple intervals (e.g., [0.5, 0.9, 0.95]). "
"Only works with estimators that support prediction intervals "
"(check capability:pred_int tag)."
),
},
Comment on lines +273 to +282
Copy link

Copilot AI Apr 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the MCP JSON schema, coverage allows type "array" but doesn't specify items (e.g., {"type": "number"}). Adding an items schema will prevent agents from passing arrays of non-numbers and makes the tool contract clearer.

Copilot uses AI. Check for mistakes.
},
"required": ["estimator_handle"],
},
Expand All @@ -277,7 +288,8 @@ async def list_tools() -> list[Tool]:
name="fit_predict_async",
description=(
"Fit an estimator on a dataset and generate predictions "
"(non-blocking background job). Returns a job_id."
"(non-blocking background job). Returns a job_id. "
"Use the coverage parameter to get prediction intervals."
),
inputSchema={
"type": "object",
Expand All @@ -295,6 +307,16 @@ async def list_tools() -> list[Tool]:
"description": "Forecast horizon (default: 12)",
"default": 12,
},
"coverage": {
"type": ["number", "array"],
"description": (
"Confidence level(s) for prediction intervals. "
"Use a single number (e.g., 0.9 for 90% intervals) or "
"a list of numbers for multiple intervals (e.g., [0.5, 0.9, 0.95]). "
"Only works with estimators that support prediction intervals "
"(check capability:pred_int tag)."
),
},
},
"required": ["estimator_handle", "dataset"],
},
Expand Down Expand Up @@ -322,6 +344,59 @@ async def list_tools() -> list[Tool]:
"required": ["estimator_handle", "dataset"],
},
),
Tool(
name="compare_estimators",
description=(
"Compare multiple estimators on the same dataset using cross-validation. "
"Runs CV on each estimator with the same dataset and metric, "
"then returns results ranked by performance. "
"Use this for agentic model selection - run multiple candidates "
"and automatically pick the best one."
),
inputSchema={
"type": "object",
"properties": {
"estimator_handles": {
"type": "array",
"items": {"type": "string"},
"description": "List of estimator handles from instantiate_estimator",
},
"dataset": {
"type": "string",
"description": (
"Dataset name: airline, sunspots, lynx, etc. "
"Use this or data_handle, not both."
),
},
"data_handle": {
"type": "string",
"description": (
"Handle from load_data_source for custom data. "
"Use this or dataset, not both."
),
},
"metric": {
"type": "string",
"description": (
"Metric to use for comparison: MAE, MAPE, MSE, RMSE, SMAPE, MASE, MedAE "
"(default: MAPE)"
),
"default": "MAPE",
},
"cv_folds": {
"type": "integer",
"description": "Number of cross-validation folds (default: 3)",
"default": 3,
},
"horizon": {
"type": "integer",
"description": "Forecast horizon for evaluation (default: 12)",
"default": 12,
},
},
"required": ["estimator_handles"],
Copy link

Copilot AI Apr 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compare_estimators MCP schema doesn't enforce that exactly one of dataset or data_handle is provided (and not both), even though the description says so and the implementation errors when neither is present. Consider expressing this via JSON Schema (e.g., oneOf/anyOf with required fields + not both) so agents can discover valid calls without trial-and-error.

Suggested change
"required": ["estimator_handles"],
"required": ["estimator_handles"],
"oneOf": [
{
"required": ["dataset"],
"not": {"required": ["data_handle"]},
},
{
"required": ["data_handle"],
"not": {"required": ["dataset"]},
},
],

Copilot uses AI. Check for mistakes.
},
),
# -- Data ------------------------------------------------------------
Tool(
name="list_available_data",
Expand Down Expand Up @@ -630,6 +705,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
arguments.get("dataset", ""),
arguments.get("horizon", 12),
data_handle=arguments.get("data_handle"),
coverage=arguments.get("coverage"),
)
result = sanitize_for_json(result)

Expand All @@ -638,6 +714,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
arguments["estimator_handle"],
arguments["dataset"],
arguments.get("horizon", 12),
coverage=arguments.get("coverage"),
)

elif name == "fit_predict_with_data":
Expand All @@ -658,6 +735,17 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
)
result = sanitize_for_json(result)

elif name == "compare_estimators":
result = compare_estimators_tool(
arguments["estimator_handles"],
dataset=arguments.get("dataset"),
data_handle=arguments.get("data_handle"),
metric=arguments.get("metric", "MAPE"),
cv_folds=arguments.get("cv_folds", 3),
horizon=arguments.get("horizon", 12),
)
result = sanitize_for_json(result)

# -- Data ------------------------------------------------------------
elif name == "list_available_data":
result = list_available_data_tool(arguments.get("is_demo"))
Expand Down
Loading
Loading