Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 2 additions & 93 deletions skore/src/skore/_sklearn/_base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
from functools import cached_property
from io import StringIO
from typing import Any, Generic, Literal, TypeVar, cast
from typing import Generic, Literal, TypeVar
from uuid import uuid4

from numpy.typing import ArrayLike, NDArray
from numpy.typing import ArrayLike
from rich.console import Console
from rich.panel import Panel
from sklearn.base import BaseEstimator
from sklearn.utils._response import _check_response_method, _get_response_values

from skore._sklearn.types import PositiveLabel
from skore._utils._cache import Cache
from skore._utils._cache_key import make_cache_key
from skore._utils._measure_time import MeasureTime
from skore._utils.repr.base import AccessorHelpMixin, ReportHelpMixin


Expand Down Expand Up @@ -114,88 +108,3 @@ def _get_X_y(
raise ValueError(
f"Invalid data source: {data_source}. Possible values are: test, train."
)


def _get_cached_response_values(
    *,
    cache: Cache,
    estimator: BaseEstimator,
    X: ArrayLike | None,
    response_method: str | list[str] | tuple[str, ...],
    pos_label: PositiveLabel | None = None,
    data_source: Literal["test", "train"] = "test",
) -> list[tuple[tuple[Any, ...], Any, bool]]:
    """Fetch response values from the cache, computing them on a miss.

    Cached predictions are read but never written back here: keeping this
    function write-free makes it safe to call concurrently from several
    threads. Persisting the returned values into the cache is the caller's
    responsibility.

    Parameters
    ----------
    cache : Cache
        The cache backend to use.

    estimator : estimator object
        The estimator used to generate the predictions.

    X : {array-like, sparse matrix} of shape (n_samples, n_features) or None
        The input data on which to compute the responses when needed.

    response_method : str, list of str or tuple of str
        The response method.

    pos_label : int, float, bool or str, default=None
        The positive label.

    data_source : {"test", "train"}, default="test"
        The data source to use.

        - "test" : use the test set provided when creating the report.
        - "train" : use the train set provided when creating the report.

    Returns
    -------
    list of tuples
        A list of tuples, each containing:

        - cache_key : tuple
            The cache key.

        - cache_value : Any
            The cache value. It corresponds to the predictions but also to the
            predict time when it has not been cached yet.

        - is_cached : bool
            Whether the cache value was loaded from the cache.
    """
    # Resolve the concrete method the estimator will actually use.
    method_name = _check_response_method(estimator, response_method).__name__

    # `pos_label` only matters for probability / decision-function outputs in
    # classification; normalize it away otherwise so cache keys stay stable.
    if method_name not in ("predict_proba", "decision_function"):
        pos_label = None

    key_kwargs = {"pos_label": pos_label}
    predictions_key = make_cache_key(data_source, method_name, key_kwargs)

    # Cache hit: return the stored predictions without touching the cache.
    if predictions_key in cache:
        y_pred = cast(NDArray, cache[predictions_key])
        return [(predictions_key, y_pred, True)]

    # Cache miss: compute the predictions and time the call.
    with MeasureTime() as timer:
        y_pred, _ = _get_response_values(
            estimator,
            X=X,
            response_method=method_name,
            pos_label=pos_label,
            return_response_method_used=False,
        )

    timing_key = make_cache_key(data_source, "predict_time")

    # Both entries are flagged as not-cached so the caller knows to store them.
    return [
        (predictions_key, y_pred, False),
        (timing_key, timer(), False),
    ]
7 changes: 4 additions & 3 deletions skore/src/skore/_sklearn/_cross_validation/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ class CrossValidationReport(_BaseReport, DirNamesMixin):
The target variable to try to predict in the case of supervised learning.

pos_label : int, float, bool or str, default=None
For binary classification, the positive class. If `None` and the target labels
are `{0, 1}` or `{-1, 1}`, the positive class is set to `1`. For other labels,
some metrics might raise an error if `pos_label` is not defined.
For binary classification, the positive class to use for metrics and displays
that need one. If `None`, skore does not infer a default positive class.
Binary metrics and displays that support it will expose all classes instead.
This parameter is rejected for non-binary tasks.

splitter : int, cross-validation generator or an iterable, default=5
Determines the cross-validation splitting strategy.
Expand Down
67 changes: 37 additions & 30 deletions skore/src/skore/_sklearn/_estimator/metrics_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sklearn.utils.metaestimators import available_if

from skore._externals._pandas_accessors import DirNamesMixin
from skore._sklearn._base import _BaseAccessor, _get_cached_response_values
from skore._sklearn._base import _BaseAccessor
from skore._sklearn._estimator.report import EstimatorReport
from skore._sklearn._plot import (
ConfusionMatrixDisplay,
Expand Down Expand Up @@ -371,29 +371,24 @@ def _compute_metric_scores(
*,
response_method: str | list[str] | tuple[str, ...],
data_source: DataSource = "test",
prediction_pos_label: PositiveLabel | None = None,
**metric_kwargs: Any,
) -> float | dict[PositiveLabel, float] | list:
X, y_true = self._get_X_y(data_source=data_source)

pos_label = self._parent.pos_label
if prediction_pos_label is None:
prediction_pos_label = pos_label

cache_key = make_cache_key(data_source, metric_fn.__name__, metric_kwargs)

score = self._parent._cache.get(cache_key)
if score is None:
results = _get_cached_response_values(
cache=self._parent._cache,
estimator=self._parent.estimator_,
X=X,
response_method=response_method,
pos_label=pos_label,
y_pred = self._parent._get_predictions(
data_source=data_source,
response_method=response_method,
pos_label=prediction_pos_label,
)
for key_tuple, value, is_cached in results:
if not is_cached:
self._parent._cache[key_tuple] = value
if key_tuple[1] != "predict_time":
y_pred = value

metric_params = inspect.signature(metric_fn).parameters
kwargs = {**metric_kwargs}
Expand Down Expand Up @@ -463,8 +458,9 @@ def timings(self) -> dict:
"""Get all measured processing times related to the estimator.

When an estimator is fitted inside the :class:`~skore.EstimatorReport`, the time
to fit is recorded. Similarly, when predictions are computed on some data, the
time to predict is recorded. This function returns all the recorded times.
to fit is recorded. Prediction time is recorded when the estimator's
`predict` method is computed and cached for a given data source. This function
returns all the recorded times.

Returns
-------
Expand Down Expand Up @@ -751,13 +747,15 @@ def brier_score(
"""
# The Brier score in scikit-learn request `pos_label` to ensure that the
# integral encoding of `y_true` corresponds to the probabilities of the
# `pos_label`. Since we get the predictions with `get_response_method`, we
# can pass any `pos_label`, they will lead to the same result.
# `pos_label`. We make sure to pass the same `pos_label` to `_get_predictions`
# than to the metric.
pos_label = self._parent.estimator_.classes_[-1]
result = self._compute_metric_scores(
sklearn.metrics.brier_score_loss,
data_source=data_source,
response_method="predict_proba",
pos_label=self._parent._estimator.classes_[-1],
pos_label=pos_label,
prediction_pos_label=pos_label,
)
return cast(float, result)

Expand Down Expand Up @@ -835,14 +833,17 @@ def roc_auc(
>>> report.metrics.roc_auc()
0.99...
"""
is_multiclass = self._parent._ml_task == "multiclass-classification"
pred_pos_label = None if is_multiclass else self._parent.estimator_.classes_[-1]
result = self._compute_metric_scores(
sklearn.metrics.roc_auc_score,
data_source=data_source,
response_method=["predict_proba", "decision_function"],
prediction_pos_label=pred_pos_label,
average=average,
multi_class=multi_class,
)
if self._parent._ml_task == "multiclass-classification" and average is None:
if is_multiclass and average is None:
return cast(dict[PositiveLabel, float], result)
return cast(float, result)

Expand Down Expand Up @@ -1120,6 +1121,7 @@ def _get_display(
| ConfusionMatrixDisplay
],
display_kwargs: dict[str, Any],
prediction_pos_label=None,
) -> (
RocCurveDisplay
| PrecisionRecallCurveDisplay
Expand Down Expand Up @@ -1181,20 +1183,14 @@ def _get_display(

data_source = cast(DataSource, data_source)
X, y_true = self._get_X_y(data_source=data_source)
if prediction_pos_label is None:
prediction_pos_label = self._parent.pos_label

results = _get_cached_response_values(
cache=self._parent._cache,
estimator=self._parent.estimator_,
X=X,
response_method=response_method,
pos_label=display_kwargs.get("pos_label"),
y_pred = self._parent._get_predictions(
data_source=data_source,
response_method=response_method,
pos_label=prediction_pos_label,
)
for key, value, is_cached in results:
if not is_cached:
self._parent._cache[key] = value
if key[1] != "predict_time":
y_pred = value

display = display_class._compute_data_for_display(
y_true=y_true,
Expand Down Expand Up @@ -1426,14 +1422,24 @@ def confusion_matrix(
>>> display = report.metrics.confusion_matrix()
>>> display.plot(threshold_value=0.7)
"""
if data_source == "both":
raise ValueError(
"data_source='both' is not supported for confusion_matrix."
)

response_method: str | list[str] | tuple[str, ...]
pos_label = self._parent.pos_label
pred_pos_label: PositiveLabel | None
if self._parent._ml_task == "binary-classification":
response_method = ("predict_proba", "decision_function")
pred_pos_label = (
self._parent.estimator_.classes_[-1] if pos_label is None else pos_label
)
else:
response_method = "predict"
pred_pos_label = None

display_kwargs = {
"display_labels": tuple(self._parent.estimator_.classes_),
"pos_label": self._parent.pos_label,
"response_method": response_method,
}
Expand All @@ -1444,6 +1450,7 @@ def confusion_matrix(
response_method=response_method,
display_class=ConfusionMatrixDisplay,
display_kwargs=display_kwargs,
prediction_pos_label=pred_pos_label,
),
)
return display
Loading
Loading