feat: add compare_estimators tool for agentic model selection (#236)
himax12 wants to merge 2 commits into sktime:main
Conversation
This fix addresses issue sktime#176, where the coverage parameter was silently dropped in async workflows, preventing prediction intervals from being returned.

Changes:
- Add a `predict_interval()` method to the `Executor` class that uses sktime's `predict_interval()` for probabilistic forecasting
- Add a `coverage` parameter to `Executor.fit_predict()`; when provided, `predict_interval()` is used instead of `predict()`
- Add a `coverage` parameter to `Executor.fit_predict_async()` with the same behavior
- Update `fit_predict_tool()` and `fit_predict_async_tool()` to accept and pass the `coverage` parameter
- Update the MCP server tool schemas for `fit_predict` and `fit_predict_async` to expose `coverage` in the API

The `coverage` parameter accepts:
- a single float (e.g., 0.9 for 90% prediction intervals)
- a list of floats for multiple interval levels (e.g., [0.5, 0.9, 0.95])

It only works with estimators that support prediction intervals (check the `capability:pred_int` tag).
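The two accepted shapes for `coverage` can be normalized up front. A minimal sketch with a hypothetical `normalize_coverage` helper (not part of this PR; names are illustrative only):

```python
def normalize_coverage(coverage):
    """Normalize the coverage argument to a list of floats, or None.

    Accepts the two documented shapes: a single number (e.g. 0.9) or a
    list of numbers (e.g. [0.5, 0.9, 0.95]).
    """
    if coverage is None:
        return None
    if isinstance(coverage, (int, float)):
        return [float(coverage)]
    return [float(c) for c in coverage]

print(normalize_coverage(0.9))            # → [0.9]
print(normalize_coverage([0.5, 0.9]))     # → [0.5, 0.9]
```

Normalizing once at the tool boundary means the executor only ever sees a list, which keeps the sync and async paths identical.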
This implements issue sktime#178: a new `compare_estimators` MCP tool that allows LLM agents to automatically compare multiple models on the same dataset and select the best one.

Features:
- Compare multiple estimator handles using cross-validation
- Support for MAE, MAPE, MSE, RMSE, SMAPE, MASE, MedAE metrics
- Works with both demo datasets and a custom `data_handle`
- Returns ranked results with the best estimator

The tool enables the full agentic loop:
1. `extract_ts_metadata()` to know what the data looks like
2. `instantiate_estimator()` x N to create candidate models
3. `compare_estimators()` (NEW) to rank all candidates automatically
4. `fit_predict(best_handle)` to execute with the winner

Implementation:
- Added a `_get_metric_instance()` helper to map metric names to sktime classes
- Added `_extract_metric_value()` to extract metrics from `evaluate` results
- Added the `compare_estimators_tool()` function
- Registered the tool in `server.py` with a full MCP schema
Pull request overview
Adds agent-facing tooling to (1) compare multiple forecasting estimators via cross-validation for automated model selection and (2) optionally return prediction intervals from fit_predict via a new coverage parameter.
Changes:
- Introduces `compare_estimators_tool` with metric selection and ranked results.
- Adds `coverage` support to `fit_predict`/`fit_predict_async` and implements `Executor.predict_interval`.
- Registers the new tool and updates MCP schemas/descriptions in `server.py`.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| src/sktime_mcp/tools/fit_predict.py | Passes through optional coverage and updates tool docs/examples. |
| src/sktime_mcp/tools/evaluate.py | Adds metric mapping/helpers and implements compare_estimators_tool. |
| src/sktime_mcp/server.py | Registers compare_estimators tool and exposes coverage in schemas. |
| src/sktime_mcp/runtime/executor.py | Adds predict_interval plus coverage-aware fit/predict flows (sync + async). |
```python
# Extract intervals into a structured format
intervals_dict = {}
for col in pred_intervals_copy.columns:
    var_name = col[0] if isinstance(col, tuple) else "predictions"
    cov_level = col[1] if isinstance(col, tuple) else coverage
    bound_type = col[2] if isinstance(col, tuple) else "bound"

    key = f"{var_name}_{cov_level}"
    if key not in intervals_dict:
        intervals_dict[key] = {"coverage": cov_level}
    intervals_dict[key][bound_type] = pred_intervals_copy[col].to_dict()

return {
    "success": True,
    "predictions": pred_intervals_copy.to_dict(),
    "intervals": intervals_dict,
    "horizon": len(fh) if hasattr(fh, "__len__") else fh,
    "coverage": coverage,
}
```
predict_interval() currently returns pred_intervals_copy.to_dict() under the predictions key, which is not point forecasts and also changes the output shape compared to predict(). This conflicts with fit_predict_tool's docs/contract that predictions are forecast values and intervals contains bounds. Consider returning point predictions from predict() plus a separate intervals structure, and align the interval key structure with the tool docs (e.g., avoid embedding coverage into the dict key unless documented).
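One way to reconcile the contract is to put `predict()` output under `predictions` and key `intervals` by variable name only. A sketch with a hypothetical `build_interval_response` helper (names and input shapes are assumptions, not the PR's code; it assumes a single coverage level so variable names don't collide):

```python
def build_interval_response(point_predictions, interval_columns, coverage):
    """Assemble the response shape documented in fit_predict_tool.

    point_predictions: {step: value}, as returned by predict()
    interval_columns: {(var_name, cov_level, bound): {step: value}},
        mimicking the (variable, coverage, lower/upper) MultiIndex
        columns that sktime's predict_interval produces
    """
    intervals = {}
    for (var_name, cov_level, bound), values in interval_columns.items():
        entry = intervals.setdefault(var_name, {"coverage": cov_level})
        entry[bound] = values
    return {
        "success": True,
        "predictions": point_predictions,  # point forecasts, not bounds
        "intervals": intervals,
        "coverage": coverage,
    }

resp = build_interval_response(
    {1: 450.2},
    {
        ("Airline", 0.9, "lower"): {1: 430.0},
        ("Airline", 0.9, "upper"): {1: 470.0},
    },
    0.9,
)
print(resp["intervals"])  # → {'Airline': {'coverage': 0.9, 'lower': {1: 430.0}, 'upper': {1: 470.0}}}
```

This matches the documented `{"Airline": {"coverage": 0.9, "lower": ..., "upper": ...}}` structure while keeping `predictions` reserved for point forecasts.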
```
@@ -41,9 +47,24 @@ def fit_predict_tool(
            "predictions": {1: 450.2, 2: 460.5, ...},
            "horizon": 12
        }

        >>> fit_predict_tool("est_abc123", "airline", horizon=12, coverage=0.9)
        {
            "success": True,
            "predictions": {...},
            "intervals": {"Airline": {"coverage": 0.9, "lower": {...}, "upper": {...}}},
            "horizon": 12,
            "coverage": 0.9
        }
```
The fit_predict_tool doc example for intervals shows a structure like {"Airline": {"coverage": 0.9, "lower": ..., "upper": ...}}, but Executor.predict_interval() currently builds keys like "Airline_0.9" and returns a different predictions payload. Please update the docs (or adjust the implementation) so the documented response matches what callers actually receive.
```python
            "default": 12,
        },
    },
    "required": ["estimator_handles"],
```
The compare_estimators MCP schema doesn't enforce that exactly one of dataset or data_handle is provided (and not both), even though the description says so and the implementation errors when neither is present. Consider expressing this via JSON Schema (e.g., oneOf/anyOf with required fields + not both) so agents can discover valid calls without trial-and-error.
Suggested change:

```python
    "required": ["estimator_handles"],
    "oneOf": [
        {
            "required": ["dataset"],
            "not": {"required": ["data_handle"]},
        },
        {
            "required": ["data_handle"],
            "not": {"required": ["dataset"]},
        },
    ],
```
```python
def compare_estimators_tool(
    estimator_handles: list[str],
    dataset: Optional[str] = None,
    data_handle: Optional[str] = None,
    metric: str = "MAPE",
    cv_folds: int = 3,
    horizon: int = 12,
) -> dict[str, Any]:
    """
    Compare multiple estimators on the same dataset using cross-validation.

    Runs cross-validation on each estimator with the same dataset and metric,
    then returns results ranked by performance.

    Args:
        estimator_handles: List of estimator handles from instantiate_estimator
        dataset: Name of demo dataset (e.g., "airline", "sunspots")
        data_handle: Alternative to dataset - handle from load_data_source for custom data
        metric: Metric to use for comparison (MAE, MAPE, MSE, RMSE, SMAPE, MASE, MedAE)
        cv_folds: Number of cross-validation folds
        horizon: Forecast horizon for evaluation

    Returns:
        Dictionary with:
        - success: bool
        - ranked: List of dicts with rank, handle, estimator name, score
        - best_handle: Handle of the best estimator
        - best_estimator: Name of the best estimator
        - metric: The metric used for comparison
        - cv_folds: Number of folds run
        - horizon: Forecast horizon used

    Example:
        >>> compare_estimators_tool(["est_abc", "est_def"], dataset="airline", metric="MAPE")
        {
            "success": True,
            "ranked": [
                {"rank": 1, "handle": "est_abc", "estimator": "AutoARIMA", "score": 4.2},
                {"rank": 2, "handle": "est_def", "estimator": "ARIMA", "score": 7.1}
            ],
            "best_handle": "est_abc",
            "best_estimator": "AutoARIMA",
            "metric": "MAPE",
            "cv_folds": 3,
            "horizon": 12
        }
    """
    executor = get_executor()

    # Validate inputs
    if not estimator_handles:
        return {"success": False, "error": "At least one estimator handle is required"}

    if not dataset and not data_handle:
        return {
            "success": False,
            "error": "Either dataset or data_handle must be provided",
        }

    # Load data
    if data_handle:
        if data_handle not in executor._data_handles:
            return {
                "success": False,
                "error": f"Unknown data handle: {data_handle}",
                "available_handles": list(executor._data_handles.keys()),
            }
        data_info = executor._data_handles[data_handle]
        y = data_info["y"]
        X = data_info.get("X")
    else:
        data_result = executor.load_dataset(dataset)
        if not data_result["success"]:
            return data_result
        y = data_result["data"]
        X = data_result.get("exog")

    # Create CV strategy
    try:
        n = len(y)
        initial_window = max(int(n * 0.5), n - cv_folds * 2)
        if initial_window < 1:
            initial_window = 1
        cv = ExpandingWindowSplitter(
            initial_window=initial_window,
            step_length=1,
            fh=list(range(1, horizon + 1)),
        )
    except Exception as e:
        return {"success": False, "error": f"Error creating CV splitter: {str(e)}"}

    # Get metric instance
    try:
        scoring = _get_metric_instance(metric)
    except ValueError as e:
        return {"success": False, "error": str(e)}

    # Evaluate each estimator
    results = []
    errors = []

    for handle in estimator_handles:
        try:
            # Get estimator instance
            try:
                instance = executor._handle_manager.get_instance(handle)
            except KeyError:
                errors.append({"handle": handle, "error": f"Handle not found: {handle}"})
                continue

            # Get estimator name for display
            try:
                handle_info = executor._handle_manager.get_info(handle)
                estimator_name = handle_info.estimator_name
            except Exception:
                estimator_name = "Unknown"

            # Run evaluation
            eval_results = evaluate(
                forecaster=instance,
                y=y,
                X=X,
                cv=cv,
                scoring=scoring,
            )

            # Extract metric value
            score = _extract_metric_value(eval_results, metric)
            if score is None:
                errors.append({
                    "handle": handle,
                    "estimator": estimator_name,
                    "error": f"Could not extract {metric} from results. "
                             f"Available columns: {list(eval_results.columns)}",
                })
                continue

            results.append({
                "handle": handle,
                "estimator": estimator_name,
                "score": score,
            })

        except Exception as e:
            errors.append({"handle": handle, "error": str(e)})

    if not results:
        return {
            "success": False,
            "error": "No estimators could be evaluated successfully",
            "errors": errors,
        }

    # Sort by score (ascending - lower is better for most metrics)
    results.sort(key=lambda x: x["score"])

    # Add ranks
    ranked = []
    for i, result in enumerate(results, 1):
        ranked.append({
            "rank": i,
            "handle": result["handle"],
            "estimator": result["estimator"],
            "score": round(result["score"], 4),
        })

    return {
        "success": True,
        "ranked": ranked,
        "best_handle": ranked[0]["handle"],
        "best_estimator": ranked[0]["estimator"],
        "metric": metric.upper(),
        "cv_folds": cv_folds,
        "horizon": horizon,
        "errors": errors if errors else None,
    }
```
New compare_estimators_tool behavior (metric selection, ranking, data_handle support, error reporting) is not covered by tests. There is already test coverage for evaluate_estimator_tool in tests/test_evaluate.py; adding analogous tests for compare_estimators_tool would help prevent regressions and will also catch issues like metric-column extraction mismatches.
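A regression test need not pull in sktime to pin down the ranking contract. A minimal sketch exercising only the pure sort-and-rank step (the `rank` helper here is a stand-in for the tool's internal logic, not the module's actual code):

```python
def rank(results):
    # Lower score is better, matching the tool's ascending sort.
    ordered = sorted(results, key=lambda r: r["score"])
    return [{"rank": i, **r} for i, r in enumerate(ordered, start=1)]

def test_best_estimator_is_rank_one():
    results = [
        {"handle": "est_def", "estimator": "ARIMA", "score": 7.1},
        {"handle": "est_abc", "estimator": "AutoARIMA", "score": 4.2},
    ]
    ranked = rank(results)
    assert ranked[0]["handle"] == "est_abc"
    assert [r["rank"] for r in ranked] == [1, 2]

test_best_estimator_is_rank_one()
```

Tests against the real tool would additionally cover `data_handle` resolution, unknown-handle error reporting, and metric-column extraction.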
```python
"coverage": {
    "type": ["number", "array"],
    "description": (
        "Confidence level(s) for prediction intervals. "
        "Use a single number (e.g., 0.9 for 90% intervals) or "
        "a list of numbers for multiple intervals (e.g., [0.5, 0.9, 0.95]). "
        "Only works with estimators that support prediction intervals "
        "(check capability:pred_int tag)."
    ),
},
```
In the MCP JSON schema, coverage allows type "array" but doesn't specify items (e.g., {"type": "number"}). Adding an items schema will prevent agents from passing arrays of non-numbers and makes the tool contract clearer.
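A sketch of what the tightened property might look like (JSON Schema form; `minItems` is an extra suggestion and the description is abridged):

```json
"coverage": {
  "type": ["number", "array"],
  "items": {"type": "number"},
  "minItems": 1,
  "description": "Confidence level(s) for prediction intervals."
}
```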
```python
if not hasattr(instance, "predict_interval"):
    return {
        "success": False,
        "error": "Estimator does not support prediction intervals. "
        "Use a probabilistic forecaster (check capability:pred_int tag).",
    }

pred_intervals = instance.predict_interval(fh=fh, X=X, coverage=coverage)
```
predict_interval() checks support via hasattr(instance, "predict_interval"), but most sktime forecasters inherit this method even when they don't support intervals (capability tag false), so this check can be ineffective and lead to confusing runtime errors. Prefer checking the estimator tag (e.g., get_tag("capability:pred_int")) and/or catching NotImplementedError to return the intended "does not support" error.
Suggested change:

```python
supports_pred_int = True
if hasattr(instance, "get_tag"):
    supports_pred_int = instance.get_tag("capability:pred_int", False)
if not supports_pred_int:
    return {
        "success": False,
        "error": "Estimator does not support prediction intervals. "
        "Use a probabilistic forecaster (check capability:pred_int tag).",
    }
try:
    pred_intervals = instance.predict_interval(fh=fh, X=X, coverage=coverage)
except NotImplementedError:
    return {
        "success": False,
        "error": "Estimator does not support prediction intervals. "
        "Use a probabilistic forecaster (check capability:pred_int tag).",
    }
```
```
@@ -217,6 +297,9 @@ def fit_predict(
        if not fit_result["success"]:
            return fit_result

        # Use predict_interval if coverage is provided, otherwise use predict
        if coverage is not None:
            return self.predict_interval(handle_id, fh=fh, X=X, coverage=coverage)
        return self.predict(handle_id, fh=fh, X=X)
```
The new coverage/prediction-interval path (Executor.predict_interval + fit_predict(..., coverage=...)) isn't exercised by existing tests. Adding a test that instantiates a forecaster with capability:pred_int and asserts the returned JSON structure (point predictions + lower/upper bounds for the requested coverage) would help ensure this API remains stable.
```python
"MSE": MeanSquaredError,
"RMSE": lambda: MeanSquaredError(square_root=True),
"MAPE": MeanAbsolutePercentageError,
"SMAPE": GeometricMeanAbsolutePercentageError,
```
The metric mapping for "SMAPE" points to GeometricMeanAbsolutePercentageError, which is a different metric (gMAPE) than symmetric MAPE. This will rank estimators using the wrong objective whenever metric="SMAPE". Please switch to the appropriate sMAPE implementation (e.g., a symmetric MAPE metric class/parameter in sktime) and keep the abbreviation-to-metric mapping consistent with the tool's documented supported metrics.
Suggested change:

```python
"SMAPE": lambda: MeanAbsolutePercentageError(symmetric=True),
```
```python
try:
    n = len(y)
    initial_window = max(int(n * 0.5), n - cv_folds * 2)
    if initial_window < 1:
        initial_window = 1
    cv = ExpandingWindowSplitter(
        initial_window=initial_window,
        step_length=1,
        fh=list(range(1, horizon + 1)),
    )
```
The cv_folds parameter is not actually used to control how many folds ExpandingWindowSplitter produces, especially once fh spans horizon steps. The current initial_window = max(int(n*0.5), n - cv_folds * 2) can yield a number of splits that differs from cv_folds and can also violate the requirement that there is enough data left for the forecasting horizon. Consider computing initial_window/step_length based on horizon and cv_folds (or explicitly truncating to the first cv_folds splits) so the tool reliably runs the requested number of folds.
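One possible fix, sketched as a hypothetical helper. It assumes the splitter yields `floor((n - initial_window - horizon) / step_length) + 1` splits, i.e. every fold keeps `horizon` points available for the test window; verify against the actual `ExpandingWindowSplitter` behavior before adopting:

```python
def expanding_cv_params(n, cv_folds, horizon, step_length=1):
    """Choose initial_window so the splitter yields exactly cv_folds splits."""
    initial_window = n - horizon - (cv_folds - 1) * step_length
    if initial_window < 1:
        raise ValueError(
            f"series of length {n} cannot support {cv_folds} folds "
            f"with horizon {horizon}"
        )
    return initial_window

# airline-sized series: 144 points, 3 folds, horizon 12
w = expanding_cv_params(144, cv_folds=3, horizon=12)
n_splits = (144 - w - 12) // 1 + 1
print(w, n_splits)  # → 130 3
```

This also fails loudly (instead of producing a too-small window) when the series cannot support the requested folds and horizon.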
```python
import logging
from typing import Any, Optional, Union
```
Union is imported but not used in this module. Dropping unused imports keeps linting clean and avoids confusion about intended types.
Suggested change:

```python
from typing import Any, Optional
```
Summary
Implements issue #178 - a new MCP tool that allows LLM agents to automatically compare multiple models on the same dataset and select the best one.
Problem
Currently, an LLM agent that requests model recommendations still has to blindly pick one estimator from the list. There is no way for the agent to automatically run multiple models on the same dataset, compare their performance, and select the best one.
Solution
A new `compare_estimators` tool that:
- compares multiple estimator handles using cross-validation,
- supports the MAE, MAPE, MSE, RMSE, SMAPE, MASE, and MedAE metrics,
- works with both demo datasets and a custom `data_handle`,
- returns ranked results together with the best estimator.

Agentic Loop

This tool enables the full autonomous workflow:
1. `extract_ts_metadata()` to learn what the data looks like
2. `instantiate_estimator()` x N to create candidate models
3. `compare_estimators()` to rank all candidates automatically
4. `fit_predict(best_handle)` to execute with the winner
API
Input:
```json
{
  "estimator_handles": ["est_abc123", "est_def456", "est_ghi789"],
  "dataset": "airline",
  "metric": "MAPE",
  "horizon": 12
}
```

Output:

```json
{
  "success": true,
  "ranked": [
    {"rank": 1, "handle": "est_abc123", "estimator": "AutoARIMA", "score": 4.2},
    {"rank": 2, "handle": "est_ghi789", "estimator": "ExponentialSmoothing", "score": 5.8},
    {"rank": 3, "handle": "est_def456", "estimator": "ARIMA", "score": 7.1}
  ],
  "best_handle": "est_abc123",
  "best_estimator": "AutoARIMA",
  "metric": "MAPE"
}
```

Supported Metrics

MAE, MAPE, MSE, RMSE, SMAPE, MASE, MedAE
Implementation
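The supported-metric lookup can be validated without sktime. A sketch with a hypothetical `get_metric_name` helper (the real `_get_metric_instance` maps names to sktime metric classes; this only mirrors its validation step):

```python
SUPPORTED_METRICS = ["MAE", "MAPE", "MSE", "RMSE", "SMAPE", "MASE", "MedAE"]

def get_metric_name(metric: str) -> str:
    """Case-insensitive lookup mirroring _get_metric_instance's validation."""
    canonical = {m.upper(): m for m in SUPPORTED_METRICS}
    try:
        return canonical[metric.upper()]
    except KeyError:
        raise ValueError(
            f"Unsupported metric: {metric!r}. Supported: {SUPPORTED_METRICS}"
        )

print(get_metric_name("mape"))   # → MAPE
print(get_metric_name("medae"))  # → MedAE
```

Accepting any casing and echoing the supported list in the error keeps the tool forgiving for agents that guess metric names.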