-
Notifications
You must be signed in to change notification settings - Fork 99
feat: add compare_estimators tool for agentic model selection #236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -184,14 +184,94 @@ def predict( | |
| except Exception as e: | ||
| return {"success": False, "error": str(e)} | ||
|
|
||
| def predict_interval( | ||
| self, | ||
| handle_id: str, | ||
| fh: Optional[Union[int, list[int]]] = None, | ||
| X: Optional[Any] = None, | ||
| coverage: Union[float, list[float]] = 0.9, | ||
| ) -> dict[str, Any]: | ||
| """Generate predictions with prediction intervals. | ||
|
|
||
| Args: | ||
| handle_id: Estimator handle | ||
| fh: Forecast horizon | ||
| X: Optional exogenous variables | ||
| coverage: Confidence level(s) for prediction intervals. | ||
| Can be a single float (e.g., 0.9 for 90% interval) or | ||
| a list of floats for multiple intervals. | ||
|
|
||
| Returns: | ||
| Dictionary with success status, predictions, and intervals | ||
| """ | ||
| try: | ||
| instance = self._handle_manager.get_instance(handle_id) | ||
| except KeyError: | ||
| return {"success": False, "error": f"Handle not found: {handle_id}"} | ||
|
|
||
| if not self._handle_manager.is_fitted(handle_id): | ||
| return {"success": False, "error": "Estimator not fitted"} | ||
|
|
||
| try: | ||
| if fh is None: | ||
| fh = list(range(1, 13)) | ||
|
|
||
| # Check if estimator supports prediction intervals | ||
| if not hasattr(instance, "predict_interval"): | ||
| return { | ||
| "success": False, | ||
| "error": "Estimator does not support prediction intervals. " | ||
| "Use a probabilistic forecaster (check capability:pred_int tag).", | ||
| } | ||
|
|
||
| pred_intervals = instance.predict_interval(fh=fh, X=X, coverage=coverage) | ||
|
|
||
| # Convert to JSON-serializable format | ||
| # The result is a DataFrame with MultiIndex columns: (variable, coverage, lower/upper) | ||
| pred_intervals_copy = pred_intervals.copy() | ||
| pred_intervals_copy.index = pred_intervals_copy.index.astype(str) | ||
|
|
||
| # Extract intervals into a structured format | ||
| intervals_dict = {} | ||
| for col in pred_intervals_copy.columns: | ||
| var_name = col[0] if isinstance(col, tuple) else "predictions" | ||
| cov_level = col[1] if isinstance(col, tuple) else coverage | ||
| bound_type = col[2] if isinstance(col, tuple) else "bound" | ||
|
|
||
| key = f"{var_name}_{cov_level}" | ||
| if key not in intervals_dict: | ||
| intervals_dict[key] = {"coverage": cov_level} | ||
| intervals_dict[key][bound_type] = pred_intervals_copy[col].to_dict() | ||
|
|
||
| return { | ||
| "success": True, | ||
| "predictions": pred_intervals_copy.to_dict(), | ||
| "intervals": intervals_dict, | ||
| "horizon": len(fh) if hasattr(fh, "__len__") else fh, | ||
| "coverage": coverage, | ||
| } | ||
|
Comment on lines
+234
to
+252
|
||
| except Exception as e: | ||
| return {"success": False, "error": str(e)} | ||
|
|
||
| def fit_predict( | ||
| self, | ||
| handle_id: str, | ||
| dataset: str, | ||
| horizon: int = 12, | ||
| data_handle: Optional[str] = None, | ||
| coverage: Optional[Union[float, list[float]]] = None, | ||
| ) -> dict[str, Any]: | ||
| """Convenience method: load data, fit, and predict.""" | ||
| """Convenience method: load data, fit, and predict. | ||
|
|
||
| Args: | ||
| handle_id: Estimator handle | ||
| dataset: Dataset name | ||
| horizon: Forecast horizon | ||
| data_handle: Optional data handle for custom data | ||
| coverage: Optional confidence level(s) for prediction intervals. | ||
| If provided, uses predict_interval() instead of predict(). | ||
| Can be a single float (e.g., 0.9) or list of floats. | ||
| """ | ||
| if data_handle is not None: | ||
| # Use custom loaded data | ||
| if data_handle not in self._data_handles: | ||
|
|
@@ -217,6 +297,9 @@ def fit_predict( | |
| if not fit_result["success"]: | ||
| return fit_result | ||
|
|
||
| # Use predict_interval if coverage is provided, otherwise use predict | ||
| if coverage is not None: | ||
| return self.predict_interval(handle_id, fh=fh, X=X, coverage=coverage) | ||
| return self.predict(handle_id, fh=fh, X=X) | ||
|
Comment on lines
187
to
303
|
||
|
|
||
| async def fit_predict_async( | ||
|
|
@@ -225,6 +308,7 @@ async def fit_predict_async( | |
| dataset: str, | ||
| horizon: int = 12, | ||
| job_id: Optional[str] = None, | ||
| coverage: Optional[Union[float, list[float]]] = None, | ||
| ) -> dict[str, Any]: | ||
| """ | ||
| Async version of fit_predict with job tracking. | ||
|
|
@@ -237,6 +321,8 @@ async def fit_predict_async( | |
| dataset: Dataset name | ||
| horizon: Forecast horizon | ||
| job_id: Optional job ID for tracking (created if not provided) | ||
| coverage: Optional confidence level(s) for prediction intervals. | ||
| If provided, uses predict_interval() instead of predict(). | ||
|
|
||
| Returns: | ||
| Dictionary with success status and job_id | ||
|
|
@@ -304,17 +390,25 @@ async def fit_predict_async( | |
| return fit_result | ||
|
|
||
| # Step 3: Generate predictions | ||
| step_msg = f"Generating predictions (horizon={horizon})" | ||
| if coverage is not None: | ||
| step_msg += f" with coverage={coverage}" | ||
| self._job_manager.update_job( | ||
| job_id, | ||
| completed_steps=2, | ||
| current_step=f"Generating predictions (horizon={horizon})...", | ||
| current_step=step_msg, | ||
| ) | ||
| await asyncio.sleep(0.01) # Yield control | ||
|
|
||
| # Run predict in executor | ||
| predict_result = await loop.run_in_executor( | ||
| None, lambda: self.predict(handle_id, fh=fh, X=X) | ||
| ) | ||
| # Run predict or predict_interval in executor | ||
| if coverage is not None: | ||
| predict_result = await loop.run_in_executor( | ||
| None, lambda: self.predict_interval(handle_id, fh=fh, X=X, coverage=coverage) | ||
| ) | ||
| else: | ||
| predict_result = await loop.run_in_executor( | ||
| None, lambda: self.predict(handle_id, fh=fh, X=X) | ||
| ) | ||
|
|
||
| if not predict_result["success"]: | ||
| self._job_manager.update_job( | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -24,7 +24,7 @@ | |||||||||||||||||||||||||
| release_data_handle_tool, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| from sktime_mcp.tools.describe_estimator import describe_estimator_tool | ||||||||||||||||||||||||||
| from sktime_mcp.tools.evaluate import evaluate_estimator_tool | ||||||||||||||||||||||||||
| from sktime_mcp.tools.evaluate import compare_estimators_tool, evaluate_estimator_tool | ||||||||||||||||||||||||||
| from sktime_mcp.tools.fit_predict import ( | ||||||||||||||||||||||||||
| fit_predict_async_tool, | ||||||||||||||||||||||||||
| fit_predict_tool, | ||||||||||||||||||||||||||
|
|
@@ -244,7 +244,8 @@ async def list_tools() -> list[Tool]: | |||||||||||||||||||||||||
| name="fit_predict", | ||||||||||||||||||||||||||
| description=( | ||||||||||||||||||||||||||
| "Fit an estimator on a dataset and generate predictions. " | ||||||||||||||||||||||||||
| "Accepts either a demo dataset name or a data_handle from load_data_source." | ||||||||||||||||||||||||||
| "Accepts either a demo dataset name or a data_handle from load_data_source. " | ||||||||||||||||||||||||||
| "Use the coverage parameter to get prediction intervals." | ||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||
| inputSchema={ | ||||||||||||||||||||||||||
| "type": "object", | ||||||||||||||||||||||||||
|
|
@@ -269,6 +270,16 @@ async def list_tools() -> list[Tool]: | |||||||||||||||||||||||||
| "description": "Forecast horizon (default: 12)", | ||||||||||||||||||||||||||
| "default": 12, | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| "coverage": { | ||||||||||||||||||||||||||
| "type": ["number", "array"], | ||||||||||||||||||||||||||
| "description": ( | ||||||||||||||||||||||||||
| "Confidence level(s) for prediction intervals. " | ||||||||||||||||||||||||||
| "Use a single number (e.g., 0.9 for 90% intervals) or " | ||||||||||||||||||||||||||
| "a list of numbers for multiple intervals (e.g., [0.5, 0.9, 0.95]). " | ||||||||||||||||||||||||||
| "Only works with estimators that support prediction intervals " | ||||||||||||||||||||||||||
| "(check capability:pred_int tag)." | ||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
|
Comment on lines
+273
to
+282
|
||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| "required": ["estimator_handle"], | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
|
|
@@ -277,7 +288,8 @@ async def list_tools() -> list[Tool]: | |||||||||||||||||||||||||
| name="fit_predict_async", | ||||||||||||||||||||||||||
| description=( | ||||||||||||||||||||||||||
| "Fit an estimator on a dataset and generate predictions " | ||||||||||||||||||||||||||
| "(non-blocking background job). Returns a job_id." | ||||||||||||||||||||||||||
| "(non-blocking background job). Returns a job_id. " | ||||||||||||||||||||||||||
| "Use the coverage parameter to get prediction intervals." | ||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||
| inputSchema={ | ||||||||||||||||||||||||||
| "type": "object", | ||||||||||||||||||||||||||
|
|
@@ -295,6 +307,16 @@ async def list_tools() -> list[Tool]: | |||||||||||||||||||||||||
| "description": "Forecast horizon (default: 12)", | ||||||||||||||||||||||||||
| "default": 12, | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| "coverage": { | ||||||||||||||||||||||||||
| "type": ["number", "array"], | ||||||||||||||||||||||||||
| "description": ( | ||||||||||||||||||||||||||
| "Confidence level(s) for prediction intervals. " | ||||||||||||||||||||||||||
| "Use a single number (e.g., 0.9 for 90% intervals) or " | ||||||||||||||||||||||||||
| "a list of numbers for multiple intervals (e.g., [0.5, 0.9, 0.95]). " | ||||||||||||||||||||||||||
| "Only works with estimators that support prediction intervals " | ||||||||||||||||||||||||||
| "(check capability:pred_int tag)." | ||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| "required": ["estimator_handle", "dataset"], | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
|
|
@@ -322,6 +344,59 @@ async def list_tools() -> list[Tool]: | |||||||||||||||||||||||||
| "required": ["estimator_handle", "dataset"], | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||
| Tool( | ||||||||||||||||||||||||||
| name="compare_estimators", | ||||||||||||||||||||||||||
| description=( | ||||||||||||||||||||||||||
| "Compare multiple estimators on the same dataset using cross-validation. " | ||||||||||||||||||||||||||
| "Runs CV on each estimator with the same dataset and metric, " | ||||||||||||||||||||||||||
| "then returns results ranked by performance. " | ||||||||||||||||||||||||||
| "Use this for agentic model selection - run multiple candidates " | ||||||||||||||||||||||||||
| "and automatically pick the best one." | ||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||
| inputSchema={ | ||||||||||||||||||||||||||
| "type": "object", | ||||||||||||||||||||||||||
| "properties": { | ||||||||||||||||||||||||||
| "estimator_handles": { | ||||||||||||||||||||||||||
| "type": "array", | ||||||||||||||||||||||||||
| "items": {"type": "string"}, | ||||||||||||||||||||||||||
| "description": "List of estimator handles from instantiate_estimator", | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| "dataset": { | ||||||||||||||||||||||||||
| "type": "string", | ||||||||||||||||||||||||||
| "description": ( | ||||||||||||||||||||||||||
| "Dataset name: airline, sunspots, lynx, etc. " | ||||||||||||||||||||||||||
| "Use this or data_handle, not both." | ||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| "data_handle": { | ||||||||||||||||||||||||||
| "type": "string", | ||||||||||||||||||||||||||
| "description": ( | ||||||||||||||||||||||||||
| "Handle from load_data_source for custom data. " | ||||||||||||||||||||||||||
| "Use this or dataset, not both." | ||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| "metric": { | ||||||||||||||||||||||||||
| "type": "string", | ||||||||||||||||||||||||||
| "description": ( | ||||||||||||||||||||||||||
| "Metric to use for comparison: MAE, MAPE, MSE, RMSE, SMAPE, MASE, MedAE " | ||||||||||||||||||||||||||
| "(default: MAPE)" | ||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||
| "default": "MAPE", | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| "cv_folds": { | ||||||||||||||||||||||||||
| "type": "integer", | ||||||||||||||||||||||||||
| "description": "Number of cross-validation folds (default: 3)", | ||||||||||||||||||||||||||
| "default": 3, | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| "horizon": { | ||||||||||||||||||||||||||
| "type": "integer", | ||||||||||||||||||||||||||
| "description": "Forecast horizon for evaluation (default: 12)", | ||||||||||||||||||||||||||
| "default": 12, | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||
| "required": ["estimator_handles"], | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
| "required": ["estimator_handles"], | |
| "required": ["estimator_handles"], | |
| "oneOf": [ | |
| { | |
| "required": ["dataset"], | |
| "not": {"required": ["data_handle"]}, | |
| }, | |
| { | |
| "required": ["data_handle"], | |
| "not": {"required": ["dataset"]}, | |
| }, | |
| ], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
predict_interval() checks support via
hasattr(instance, "predict_interval"), but most sktime forecasters inherit this method even when they don't support intervals (capability tag false), so this check can be ineffective and lead to confusing runtime errors. Prefer checking the estimator tag (e.g., get_tag("capability:pred_int")) and/or catching NotImplementedError to return the intended "does not support" error.