diff --git a/src/sktime_mcp/runtime/executor.py b/src/sktime_mcp/runtime/executor.py
index 35a2ab21..67c79452 100644
--- a/src/sktime_mcp/runtime/executor.py
+++ b/src/sktime_mcp/runtime/executor.py
@@ -184,14 +184,94 @@ def predict(
         except Exception as e:
             return {"success": False, "error": str(e)}

+    def predict_interval(
+        self,
+        handle_id: str,
+        fh: Optional[Union[int, list[int]]] = None,
+        X: Optional[Any] = None,
+        coverage: Union[float, list[float]] = 0.9,
+    ) -> dict[str, Any]:
+        """Generate predictions with prediction intervals.
+
+        Args:
+            handle_id: Estimator handle
+            fh: Forecast horizon
+            X: Optional exogenous variables
+            coverage: Confidence level(s) for prediction intervals.
+                Can be a single float (e.g., 0.9 for a 90% interval) or
+                a list of floats for multiple intervals.
+
+        Returns:
+            Dictionary with success status, predictions, and intervals
+        """
+        try:
+            instance = self._handle_manager.get_instance(handle_id)
+        except KeyError:
+            return {"success": False, "error": f"Handle not found: {handle_id}"}
+
+        if not self._handle_manager.is_fitted(handle_id):
+            return {"success": False, "error": "Estimator not fitted"}
+
+        try:
+            if fh is None:
+                fh = list(range(1, 13))
+
+            # All sktime forecasters define predict_interval, so a hasattr
+            # check never fires; consult the capability tag instead
+            if not instance.get_tag(
+                "capability:pred_int", tag_value_default=False, raise_error=False
+            ):
+                return {
+                    "success": False,
+                    "error": "Estimator does not support prediction intervals. "
+                    "Use a probabilistic forecaster (check capability:pred_int tag).",
+                }
+
+            pred_intervals = instance.predict_interval(fh=fh, X=X, coverage=coverage)
+
+            # Convert to a JSON-serializable format. The result is a DataFrame
+            # with MultiIndex columns: (variable, coverage, lower/upper)
+            pred_intervals_copy = pred_intervals.copy()
+            pred_intervals_copy.index = pred_intervals_copy.index.astype(str)
+
+            # Extract intervals into a structured format
+            intervals_dict = {}
+            for col in pred_intervals_copy.columns:
+                var_name = col[0] if isinstance(col, tuple) else "predictions"
+                cov_level = col[1] if isinstance(col, tuple) else coverage
+                bound_type = col[2] if isinstance(col, tuple) else "bound"
+
+                key = f"{var_name}_{cov_level}"
+                if key not in intervals_dict:
+                    intervals_dict[key] = {"coverage": cov_level}
+                intervals_dict[key][bound_type] = pred_intervals_copy[col].to_dict()
+
+            return {
+                "success": True,
+                # Stringify the MultiIndex column tuples: tuples are not
+                # valid JSON object keys
+                "predictions": {
+                    str(col): pred_intervals_copy[col].to_dict()
+                    for col in pred_intervals_copy.columns
+                },
+                "intervals": intervals_dict,
+                "horizon": len(fh) if hasattr(fh, "__len__") else fh,
+                "coverage": coverage,
+            }
+        except Exception as e:
+            return {"success": False, "error": str(e)}
+
     def fit_predict(
         self,
         handle_id: str,
         dataset: str,
         horizon: int = 12,
         data_handle: Optional[str] = None,
+        coverage: Optional[Union[float, list[float]]] = None,
     ) -> dict[str, Any]:
-        """Convenience method: load data, fit, and predict."""
+        """Convenience method: load data, fit, and predict.
+
+        Args:
+            handle_id: Estimator handle
+            dataset: Dataset name
+            horizon: Forecast horizon
+            data_handle: Optional data handle for custom data
+            coverage: Optional confidence level(s) for prediction intervals.
+                If provided, uses predict_interval() instead of predict().
+                Can be a single float (e.g., 0.9) or a list of floats.
+ """ if data_handle is not None: # Use custom loaded data if data_handle not in self._data_handles: @@ -217,6 +297,9 @@ def fit_predict( if not fit_result["success"]: return fit_result + # Use predict_interval if coverage is provided, otherwise use predict + if coverage is not None: + return self.predict_interval(handle_id, fh=fh, X=X, coverage=coverage) return self.predict(handle_id, fh=fh, X=X) async def fit_predict_async( @@ -225,6 +308,7 @@ async def fit_predict_async( dataset: str, horizon: int = 12, job_id: Optional[str] = None, + coverage: Optional[Union[float, list[float]]] = None, ) -> dict[str, Any]: """ Async version of fit_predict with job tracking. @@ -237,6 +321,8 @@ async def fit_predict_async( dataset: Dataset name horizon: Forecast horizon job_id: Optional job ID for tracking (created if not provided) + coverage: Optional confidence level(s) for prediction intervals. + If provided, uses predict_interval() instead of predict(). Returns: Dictionary with success status and job_id @@ -304,17 +390,25 @@ async def fit_predict_async( return fit_result # Step 3: Generate predictions + step_msg = f"Generating predictions (horizon={horizon})" + if coverage is not None: + step_msg += f" with coverage={coverage}" self._job_manager.update_job( job_id, completed_steps=2, - current_step=f"Generating predictions (horizon={horizon})...", + current_step=step_msg, ) await asyncio.sleep(0.01) # Yield control - # Run predict in executor - predict_result = await loop.run_in_executor( - None, lambda: self.predict(handle_id, fh=fh, X=X) - ) + # Run predict or predict_interval in executor + if coverage is not None: + predict_result = await loop.run_in_executor( + None, lambda: self.predict_interval(handle_id, fh=fh, X=X, coverage=coverage) + ) + else: + predict_result = await loop.run_in_executor( + None, lambda: self.predict(handle_id, fh=fh, X=X) + ) if not predict_result["success"]: self._job_manager.update_job( diff --git a/src/sktime_mcp/server.py b/src/sktime_mcp/server.py index e377ff2f..c252ff00 100644 --- a/src/sktime_mcp/server.py +++ b/src/sktime_mcp/server.py @@ -24,7 +24,7 @@ release_data_handle_tool, ) from sktime_mcp.tools.describe_estimator import describe_estimator_tool -from sktime_mcp.tools.evaluate import evaluate_estimator_tool +from sktime_mcp.tools.evaluate import compare_estimators_tool, evaluate_estimator_tool from sktime_mcp.tools.fit_predict import ( fit_predict_async_tool, fit_predict_tool, @@ -244,7 +244,8 @@ async def list_tools() -> list[Tool]: name="fit_predict", description=( "Fit an estimator on a dataset and generate predictions. " - "Accepts either a demo dataset name or a data_handle from load_data_source." + "Accepts either a demo dataset name or a data_handle from load_data_source. " + "Use the coverage parameter to get prediction intervals." ), inputSchema={ "type": "object", @@ -269,6 +270,16 @@ async def list_tools() -> list[Tool]: "description": "Forecast horizon (default: 12)", "default": 12, }, + "coverage": { + "type": ["number", "array"], + "description": ( + "Confidence level(s) for prediction intervals. " + "Use a single number (e.g., 0.9 for 90% intervals) or " + "a list of numbers for multiple intervals (e.g., [0.5, 0.9, 0.95]). " + "Only works with estimators that support prediction intervals " + "(check capability:pred_int tag)." 
+                        ),
+                    },
                 },
                 "required": ["estimator_handle"],
             },
         ),
         Tool(
             name="fit_predict_async",
             description=(
                 "Fit an estimator on a dataset and generate predictions "
-                "(non-blocking background job). Returns a job_id."
+                "(non-blocking background job). Returns a job_id. "
+                "Use the coverage parameter to get prediction intervals."
             ),
             inputSchema={
                 "type": "object",
                 "properties": {
@@ -295,6 +307,16 @@ async def list_tools() -> list[Tool]:
                         "description": "Forecast horizon (default: 12)",
                         "default": 12,
                     },
+                    "coverage": {
+                        "type": ["number", "array"],
+                        "description": (
+                            "Confidence level(s) for prediction intervals. "
+                            "Use a single number (e.g., 0.9 for 90% intervals) or "
+                            "a list of numbers for multiple intervals (e.g., [0.5, 0.9, 0.95]). "
+                            "Only works with estimators that support prediction intervals "
+                            "(check capability:pred_int tag)."
+                        ),
+                    },
                 },
                 "required": ["estimator_handle", "dataset"],
             },
@@ -322,6 +344,59 @@ async def list_tools() -> list[Tool]:
                 "required": ["estimator_handle", "dataset"],
             },
         ),
+        Tool(
+            name="compare_estimators",
+            description=(
+                "Compare multiple estimators on the same dataset using cross-validation. "
+                "Runs CV on each estimator with the same dataset and metric, "
+                "then returns results ranked by performance. "
+                "Use this for agentic model selection - run multiple candidates "
+                "and automatically pick the best one."
+            ),
+            inputSchema={
+                "type": "object",
+                "properties": {
+                    "estimator_handles": {
+                        "type": "array",
+                        "items": {"type": "string"},
+                        "description": "List of estimator handles from instantiate_estimator",
+                    },
+                    "dataset": {
+                        "type": "string",
+                        "description": (
+                            "Dataset name: airline, sunspots, lynx, etc. "
+                            "Use this or data_handle, not both."
+                        ),
+                    },
+                    "data_handle": {
+                        "type": "string",
+                        "description": (
+                            "Handle from load_data_source for custom data. "
+                            "Use this or dataset, not both."
+                        ),
+                    },
+                    "metric": {
+                        "type": "string",
+                        "description": (
+                            "Metric to use for comparison: MAE, MAPE, MSE, RMSE, SMAPE, MASE, MedAE "
+                            "(default: MAPE)"
+                        ),
+                        "default": "MAPE",
+                    },
+                    "cv_folds": {
+                        "type": "integer",
+                        "description": "Number of cross-validation folds (default: 3)",
+                        "default": 3,
+                    },
+                    "horizon": {
+                        "type": "integer",
+                        "description": "Forecast horizon for evaluation (default: 12)",
+                        "default": 12,
+                    },
+                },
+                "required": ["estimator_handles"],
+            },
+        ),
         # -- Data ------------------------------------------------------------
         Tool(
             name="list_available_data",
@@ -630,6 +705,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
             arguments.get("dataset", ""),
             arguments.get("horizon", 12),
             data_handle=arguments.get("data_handle"),
+            coverage=arguments.get("coverage"),
         )
         result = sanitize_for_json(result)

@@ -638,6 +714,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
             arguments["estimator_handle"],
             arguments["dataset"],
             arguments.get("horizon", 12),
+            coverage=arguments.get("coverage"),
         )

     elif name == "fit_predict_with_data":
@@ -658,6 +735,17 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
         )
         result = sanitize_for_json(result)

+    elif name == "compare_estimators":
+        result = compare_estimators_tool(
+            arguments["estimator_handles"],
+            dataset=arguments.get("dataset"),
+            data_handle=arguments.get("data_handle"),
+            metric=arguments.get("metric", "MAPE"),
+            cv_folds=arguments.get("cv_folds", 3),
+            horizon=arguments.get("horizon", 12),
+        )
+        result = sanitize_for_json(result)
+
     # -- Data ------------------------------------------------------------
     elif name == "list_available_data":
         result = list_available_data_tool(arguments.get("is_demo"))
diff --git a/src/sktime_mcp/tools/evaluate.py b/src/sktime_mcp/tools/evaluate.py
index 3ba8678a..c31adc17 100644
--- a/src/sktime_mcp/tools/evaluate.py
+++ b/src/sktime_mcp/tools/evaluate.py
@@ -5,16 +5,103 @@
 """

 import logging
-from typing import Any
+from typing import Any, Optional

 from sktime.forecasting.model_evaluation import evaluate
 from sktime.forecasting.model_selection import ExpandingWindowSplitter
+from sktime.performance_metrics.forecasting import (
+    MeanAbsoluteError,
+    MeanAbsolutePercentageError,
+    MeanSquaredError,
+    MedianAbsoluteError,
+)

 from sktime_mcp.runtime.executor import get_executor

 logger = logging.getLogger(__name__)

+# Mapping of metric names to sktime metric classes or zero-argument factories.
+# MASE needs a seasonality parameter and is constructed in _get_metric_instance
+# instead of being listed here.
+METRIC_NAME_TO_CLASS = {
+    "MAE": MeanAbsoluteError,
+    "MSE": MeanSquaredError,
+    "RMSE": lambda: MeanSquaredError(square_root=True),
+    "MAPE": MeanAbsolutePercentageError,
+    # sMAPE is symmetric MAPE, not geometric MAPE
+    "SMAPE": lambda: MeanAbsolutePercentageError(symmetric=True),
+    "MedAE": MedianAbsoluteError,
+}
+
+
+def _get_metric_instance(metric_name: str) -> Any:
+    """Get a sktime metric instance from a metric name.
+
+    Args:
+        metric_name: Name of the metric (MAE, MAPE, MSE, RMSE, SMAPE, MASE, MedAE)
+
+    Returns:
+        An instance of the appropriate sktime metric class
+
+    Raises:
+        ValueError: If the metric is not supported
+    """
+    metric_name = metric_name.upper()
+
+    if metric_name == "MASE":
+        # MASE requires a seasonality parameter - use default sp=1
+        from sktime.performance_metrics.forecasting import MeanAbsoluteScaledError
+
+        return MeanAbsoluteScaledError(sp=1)
+
+    if metric_name not in METRIC_NAME_TO_CLASS:
+        available = list(METRIC_NAME_TO_CLASS.keys()) + ["MASE"]
+        raise ValueError(
+            f"Unknown metric: {metric_name}. Available metrics: {available}"
+        )
+
+    # Every value in the mapping is callable (a metric class or a zero-argument
+    # factory), so instantiation is uniform
+    return METRIC_NAME_TO_CLASS[metric_name]()
+
+
+def _extract_metric_value(results: Any, metric_name: str) -> Optional[float]:
+    """Extract the metric value from evaluate results.
+
+    sktime's evaluate() names the scoring column 'test_<metric class name>'
+    (e.g., 'test_MeanAbsolutePercentageError'), which does not match short
+    names like 'MAPE', so fall back to the single 'test_' column if needed.
+
+    Args:
+        results: DataFrame from sktime evaluate()
+        metric_name: Name of the metric to extract
+
+    Returns:
+        The metric value as a float, or None if not found
+    """
+    metric_col = None
+    for col in results.columns:
+        col_name = str(col)
+        if col_name in (metric_name, f"test_{metric_name}"):
+            metric_col = col
+            break
+        # Check case-insensitive match
+        if col_name.lower() == f"test_{metric_name.lower()}":
+            metric_col = col
+            break
+
+    if metric_col is None:
+        # evaluate() is called with a single scoring, so exactly one column
+        # carries the 'test_' prefix - use it as the fallback
+        test_cols = [c for c in results.columns if str(c).startswith("test_")]
+        if len(test_cols) == 1:
+            metric_col = test_cols[0]
+
+    if metric_col is None:
+        return None
+
+    # Get the mean of the metric across folds
+    values = results[metric_col].dropna()
+    if len(values) == 0:
+        return None
+
+    return float(values.mean())
+
+
 def evaluate_estimator_tool(
     estimator_handle: str,
     dataset: str,
@@ -71,3 +158,180 @@ def evaluate_estimator_tool(
     except Exception as e:
         logger.exception("Error during evaluate")
         return {"success": False, "error": str(e)}
+
+
+def compare_estimators_tool(
+    estimator_handles: list[str],
+    dataset: Optional[str] = None,
+    data_handle: Optional[str] = None,
+    metric: str = "MAPE",
+    cv_folds: int = 3,
+    horizon: int = 12,
+) -> dict[str, Any]:
+    """
+    Compare multiple estimators on the same dataset using cross-validation.
+
+    Runs cross-validation on each estimator with the same dataset and metric,
+    then returns results ranked by performance.
+
+    Args:
+        estimator_handles: List of estimator handles from instantiate_estimator
+        dataset: Name of demo dataset (e.g., "airline", "sunspots")
+        data_handle: Alternative to dataset - handle from load_data_source for custom data
+        metric: Metric to use for comparison (MAE, MAPE, MSE, RMSE, SMAPE, MASE, MedAE)
+        cv_folds: Number of cross-validation folds
+        horizon: Forecast horizon for evaluation
+
+    Returns:
+        Dictionary with:
+        - success: bool
+        - ranked: List of dicts with rank, handle, estimator name, score
+        - best_handle: Handle of the best estimator
+        - best_estimator: Name of the best estimator
+        - metric: The metric used for comparison
+        - cv_folds: Number of folds run
+        - horizon: Forecast horizon used
+
+    Example:
+        >>> compare_estimators_tool(["est_abc", "est_def"], dataset="airline", metric="MAPE")
+        {
+            "success": True,
+            "ranked": [
+                {"rank": 1, "handle": "est_abc", "estimator": "AutoARIMA", "score": 0.042},
+                {"rank": 2, "handle": "est_def", "estimator": "ARIMA", "score": 0.071}
+            ],
+            "best_handle": "est_abc",
+            "best_estimator": "AutoARIMA",
+            "metric": "MAPE",
+            "cv_folds": 3,
+            "horizon": 12
+        }
+    """
+    executor = get_executor()
+
+    # Validate inputs
+    if not estimator_handles:
+        return {"success": False, "error": "At least one estimator handle is required"}
+
+    if not dataset and not data_handle:
+        return {
+            "success": False,
+            "error": "Either dataset or data_handle must be provided",
+        }
+
+    # Load data
+    if data_handle:
+        if data_handle not in executor._data_handles:
+            return {
+                "success": False,
+                "error": f"Unknown data handle: {data_handle}",
+                "available_handles": list(executor._data_handles.keys()),
+            }
+        data_info = executor._data_handles[data_handle]
+        y = data_info["y"]
+        X = data_info.get("X")
+    else:
+        data_result = executor.load_dataset(dataset)
+        if not data_result["success"]:
+            return data_result
+        y = data_result["data"]
+        X = data_result.get("exog")
+
+    # Create CV strategy
+    try:
+        n = len(y)
+        # Choose the first training window so that exactly cv_folds expanding
+        # windows fit, each followed by a horizon-step test window
+        initial_window = n - horizon - (cv_folds - 1)
+        if initial_window < 1:
+            return {
+                "success": False,
+                "error": (
+                    f"Series of length {n} is too short for "
+                    f"cv_folds={cv_folds} with horizon={horizon}"
+                ),
+            }
+        cv = ExpandingWindowSplitter(
+            initial_window=initial_window,
+            step_length=1,
+            fh=list(range(1, horizon + 1)),
+        )
+    except Exception as e:
+        return {"success": False, "error": f"Error creating CV splitter: {str(e)}"}

+    # Get metric instance
+    try:
+        scoring = _get_metric_instance(metric)
+    except ValueError as e:
+        return {"success": False, "error": str(e)}
+
+    # Evaluate each estimator
+    results = []
+    errors = []
+
+    for handle in estimator_handles:
+        try:
+            # Get estimator instance
+            try:
+                instance = executor._handle_manager.get_instance(handle)
+            except KeyError:
+                errors.append({"handle": handle, "error": f"Handle not found: {handle}"})
+                continue
+
+            # Get estimator name for display
+            try:
+                handle_info = executor._handle_manager.get_info(handle)
+                estimator_name = handle_info.estimator_name
+            except Exception:
+                estimator_name = "Unknown"
+
+            # Run evaluation
+            eval_results = evaluate(
+                forecaster=instance,
+                y=y,
+                X=X,
+                cv=cv,
+                scoring=scoring,
+            )
+
+            # Extract metric value
+            score = _extract_metric_value(eval_results, metric)
+            if score is None:
+                errors.append({
+                    "handle": handle,
+                    "estimator": estimator_name,
+                    "error": (
+                        f"Could not extract {metric} from results. "
+                        f"Available columns: {list(eval_results.columns)}"
+                    ),
+                })
+                continue
+
+            results.append({
+                "handle": handle,
+                "estimator": estimator_name,
+                "score": score,
+            })
+
+        except Exception as e:
+            errors.append({"handle": handle, "error": str(e)})
+
+    if not results:
+        return {
+            "success": False,
+            "error": "No estimators could be evaluated successfully",
+            "errors": errors,
+        }
+
+    # Sort by score ascending - every supported metric is an error, so lower is better
+    results.sort(key=lambda x: x["score"])
+
+    # Add ranks
+    ranked = []
+    for i, result in enumerate(results, 1):
+        ranked.append({
+            "rank": i,
+            "handle": result["handle"],
+            "estimator": result["estimator"],
+            "score": round(result["score"], 4),
+        })
+
+    return {
+        "success": True,
+        "ranked": ranked,
+        "best_handle": ranked[0]["handle"],
+        "best_estimator": ranked[0]["estimator"],
+        "metric": metric.upper(),
+        "cv_folds": cv_folds,
+        "horizon": horizon,
+        "errors": errors if errors else None,
+    }
diff --git a/src/sktime_mcp/tools/fit_predict.py b/src/sktime_mcp/tools/fit_predict.py
index de9c1b5f..ee6fe872 100644
--- a/src/sktime_mcp/tools/fit_predict.py
+++ b/src/sktime_mcp/tools/fit_predict.py
@@ -6,7 +6,7 @@

 import asyncio
 import logging
-from typing import Any, Optional
+from typing import Any, Optional, Union

 from sktime_mcp.runtime.executor import get_executor

@@ -18,6 +18,7 @@ def fit_predict_tool(
     dataset: str,
     horizon: int = 12,
     data_handle: Optional[str] = None,
+    coverage: Optional[Union[float, list[float]]] = None,
 ) -> dict[str, Any]:
     """
     Execute a complete fit-predict workflow.
@@ -27,12 +28,17 @@ def fit_predict_tool(
         dataset: Name of demo dataset (e.g., "airline", "sunspots")
         horizon: Forecast horizon (default: 12)
         data_handle: Optional handle from load_data_source for custom data
+        coverage: Optional confidence level(s) for prediction intervals.
+            Can be a single float (e.g., 0.9 for a 90% interval) or
+            a list of floats for multiple intervals.

     Returns:
         Dictionary with:
         - success: bool
         - predictions: Forecast values
         - horizon: Number of steps predicted
+        - intervals: (if coverage provided) Prediction intervals with lower/upper bounds
+        - coverage: (if coverage provided) The coverage level(s) used

     Example:
         >>> fit_predict_tool("est_abc123", "airline", horizon=12)
@@ -41,9 +47,24 @@ def fit_predict_tool(
             "predictions": {1: 450.2, 2: 460.5, ...},
             "horizon": 12
         }
+
+        >>> fit_predict_tool("est_abc123", "airline", horizon=12, coverage=0.9)
+        {
+            "success": True,
+            "predictions": {...},
+            "intervals": {"Airline_0.9": {"coverage": 0.9, "lower": {...}, "upper": {...}}},
+            "horizon": 12,
+            "coverage": 0.9
+        }
     """
     executor = get_executor()
-    return executor.fit_predict(estimator_handle, dataset, horizon, data_handle=data_handle)
+    return executor.fit_predict(
+        estimator_handle,
+        dataset,
+        horizon,
+        data_handle=data_handle,
+        coverage=coverage,
+    )


 def fit_tool(
@@ -109,6 +130,7 @@ def fit_predict_async_tool(
     estimator_handle: str,
     dataset: str,
     horizon: int = 12,
+    coverage: Optional[Union[float, list[float]]] = None,
 ) -> dict[str, Any]:
     """
     Execute a fit-predict workflow in the background (non-blocking).
@@ -120,6 +142,9 @@ def fit_predict_async_tool(
         estimator_handle: Handle from instantiate_estimator
         dataset: Name of demo dataset (e.g., "airline", "sunspots")
         horizon: Forecast horizon (default: 12)
+        coverage: Optional confidence level(s) for prediction intervals.
+            Can be a single float (e.g., 0.9 for a 90% interval) or
+            a list of floats for multiple intervals.

     Returns:
         Dictionary with:
@@ -134,6 +159,14 @@ def fit_predict_async_tool(
             "job_id": "abc-123-def-456",
             "message": "Training job started. Use check_job_status to monitor progress."
         }
+
+        >>> fit_predict_async_tool("est_abc123", "airline", horizon=12, coverage=0.9)
+        {
+            "success": True,
+            "job_id": "abc-123-def-456",
+            "message": "Training job started for ThetaForecaster on airline. Use check_job_status('abc-123-def-456') to monitor progress.",
+            "coverage": 0.9
+        }
     """
     from sktime_mcp.runtime.jobs import get_job_manager

@@ -168,10 +201,10 @@ def fit_predict_async_tool(
         asyncio.set_event_loop(loop)

     # Schedule the coroutine (non-blocking!)
-    coro = executor.fit_predict_async(estimator_handle, dataset, horizon, job_id)
+    coro = executor.fit_predict_async(estimator_handle, dataset, horizon, job_id, coverage)
     asyncio.run_coroutine_threadsafe(coro, loop)

-    return {
+    result = {
         "success": True,
         "job_id": job_id,
         "message": f"Training job started for {estimator_name} on {dataset}. Use check_job_status('{job_id}') to monitor progress.",
@@ -179,3 +212,8 @@ def fit_predict_async_tool(
         "dataset": dataset,
         "horizon": horizon,
     }
+
+    if coverage is not None:
+        result["coverage"] = coverage
+
+    return result
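
Usage sketch (not part of the diff): a minimal end-to-end smoke test for the two new paths, coverage-based prediction intervals and estimator comparison. The two imports match this diff; the handle strings are placeholders, since real handles come from the instantiate_estimator tool, whose module is unchanged here, so treat this as an illustration of the return shapes rather than a test that runs as-is.

"""Smoke-test sketch for the coverage and compare_estimators paths.

The imports exist in this diff; the estimator handles below are
placeholders for handles produced by the instantiate_estimator tool.
"""
from sktime_mcp.tools.evaluate import compare_estimators_tool
from sktime_mcp.tools.fit_predict import fit_predict_tool

# Point forecasts plus 50% and 90% prediction intervals on the airline demo set.
result = fit_predict_tool("est_theta", "airline", horizon=12, coverage=[0.5, 0.9])
if result["success"]:
    # Interval keys follow the f"{variable}_{coverage}" format built in
    # predict_interval, e.g. "Airline_0.5" and "Airline_0.9", each holding
    # "lower" and "upper" bound mappings.
    print(sorted(result["intervals"]))
else:
    print("error:", result["error"])

# Rank two candidates by mean MAPE over 3 expanding-window folds; lower is better.
comparison = compare_estimators_tool(
    ["est_theta", "est_naive"],
    dataset="airline",
    metric="MAPE",
    cv_folds=3,
    horizon=12,
)
if comparison["success"]:
    for row in comparison["ranked"]:
        print(row["rank"], row["estimator"], row["score"])
    print("best:", comparison["best_handle"])
else:
    print("error:", comparison["error"], comparison.get("errors"))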