
refactor(skore): Make pos_label more consistent #2663

Open
cakedev0 wants to merge 21 commits into main from pos_label_consistency

Conversation

Contributor

@cakedev0 cakedev0 commented Mar 25, 2026

Towards #2642 and #2592

Change description

Go from the behavior table described in #2592 to this:

| labels | pos_label | precision/recall | roc_auc | brier_score | roc/PR/conf. matrix |
|---|---|---|---|---|---|
| [0,1] / ["A","B"] | unset | all classes | scalar | scalar | all classes |
| [0,1] / ["A","B"] | fixed (1/"A") | pos_label=1/"A" | scalar | scalar | pos_label=1/"A" |
| [0,1,2] | unset | all classes | all classes | attribute error | all classes |
| [0,1,2] | fixed (2) | error at init | error at init | error at init | error at init |
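The binary rows of this table map directly onto scikit-learn's metric semantics; here is a minimal illustration with plain scikit-learn calls (not skore's actual accessors):

```python
import numpy as np
from sklearn.metrics import precision_score

y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1])

# pos_label unset -> report the metric for all classes
per_class = precision_score(y_true, y_pred, average=None)

# pos_label fixed -> a single scalar for that class
pos = precision_score(y_true, y_pred, pos_label=1)

# multiclass labels with a fixed pos_label are ambiguous -> error
y_multi = np.array([0, 1, 2, 1, 0])
try:
    precision_score(y_multi, y_multi, pos_label=2)
except ValueError:
    pass  # scikit-learn rejects pos_label with the default binary average here
```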

This PR also refactors how predictions are made, both to make those changes easier and to prepare the ground for moving pos_label from the init to the arguments of metrics/plots. Indeed, cached predictions no longer depend on pos_label (we adapt for pos_label on the fly when needed).
By doing so, it fixes #2671 (with option 1 described in the issue).
This refactor also fixes #2672.
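The on-the-fly adaptation can be sketched as follows (a hypothetical helper, not skore's actual code): the full predict_proba output is cached once, keyed independently of pos_label, and the relevant column is selected when a metric asks for it.

```python
import numpy as np

def select_pos_label_column(y_proba, classes, pos_label):
    """Hypothetical helper: adapt a cached (n_samples, n_classes)
    predict_proba array to a requested pos_label on the fly, so the
    cache key no longer needs to include pos_label."""
    (idx,) = np.flatnonzero(classes == pos_label)
    return y_proba[:, idx]

classes = np.array(["A", "B"])
y_proba = np.array([[0.2, 0.8], [0.9, 0.1]])  # cached once
col = select_pos_label_column(y_proba, classes, "B")
```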

Contribution checklist

  • Unit tests were added or updated
  • Documentation was added or updated. TODO: check I updated everything that was needed.
  • TODO? A new changelog entry was added to CHANGELOG.rst

AI usage disclosure

For review and tests; not for the core changes, those were a bit too delicate* for AI.

*too delicate = not well-specified enough when I started working on it 😅

@cakedev0 cakedev0 marked this pull request as draft March 25, 2026 09:08
@cakedev0 cakedev0 changed the title from refactor(skore): Make pos_ label more consistent. to refactor(skore): Make pos_ label more consistent Mar 25, 2026
@github-actions
Contributor

github-actions bot commented Mar 25, 2026

Documentation preview @ 9d25578

@github-actions
Contributor

github-actions bot commented Mar 25, 2026

Coverage

Coverage Report for skore/

| File | Stmts | Miss | Cover | Missing |
|---|---|---|---|---|
| `skore/src/skore` | | | | |
| `__init__.py` | 28 | 0 | 100% | |
| `_config.py` | 44 | 1 | 97% | 57 |
| `exceptions.py` | 4 | 4 | 0% | 4, 15, 19, 23 |
| `skore/src/skore/_project` | | | | |
| `__init__.py` | 0 | 0 | 100% | |
| `_summary.py` | 80 | 1 | 98% | 121 |
| `_widget.py` | 195 | 0 | 100% | |
| `login.py` | 13 | 2 | 84% | 65–66 |
| `plugin.py` | 12 | 0 | 100% | |
| `project.py` | 54 | 2 | 96% | 131, 140 |
| `types.py` | 3 | 0 | 100% | |
| `skore/src/skore/_sklearn` | | | | |
| `__init__.py` | 8 | 0 | 100% | |
| `_base.py` | 38 | 0 | 100% | |
| `compare.py` | 5 | 0 | 100% | |
| `evaluate.py` | 27 | 1 | 96% | 133 |
| `feature_names.py` | 26 | 0 | 100% | |
| `find_ml_task.py` | 61 | 0 | 100% | |
| `types.py` | 21 | 1 | 95% | 29 |
| `skore/src/skore/_sklearn/_comparison` | | | | |
| `__init__.py` | 7 | 0 | 100% | |
| `inspection_accessor.py` | 26 | 1 | 96% | 352 |
| `metrics_accessor.py` | 97 | 3 | 96% | 214, 822, 989 |
| `report.py` | 123 | 7 | 94% | 479, 486, 489, 495, 536–538 |
| `skore/src/skore/_sklearn/_cross_validation` | | | | |
| `__init__.py` | 9 | 0 | 100% | |
| `data_accessor.py` | 36 | 2 | 94% | 48, 74 |
| `inspection_accessor.py` | 26 | 1 | 96% | 324 |
| `metrics_accessor.py` | 93 | 2 | 97% | 821, 965 |
| `report.py` | 122 | 6 | 95% | 430, 433, 439, 504–506 |
| `skore/src/skore/_sklearn/_estimator` | | | | |
| `__init__.py` | 9 | 0 | 100% | |
| `data_accessor.py` | 48 | 1 | 97% | 177 |
| `inspection_accessor.py` | 34 | 0 | 100% | |
| `metrics_accessor.py` | 262 | 2 | 99% | 310, 1106 |
| `report.py` | 210 | 13 | 93% | 209, 418, 452, 458, 509, 512, 531, 533, 539–540, 607–609 |
| `skore/src/skore/_sklearn/_plot` | | | | |
| `__init__.py` | 3 | 0 | 100% | |
| `base.py` | 61 | 2 | 96% | 61–62 |
| `utils.py` | 149 | 7 | 95% | 63, 65–66, 68, 274–275, 454 |
| `skore/src/skore/_sklearn/_plot/data` | | | | |
| `__init__.py` | 2 | 0 | 100% | |
| `table_report.py` | 177 | 1 | 99% | 670 |
| `skore/src/skore/_sklearn/_plot/inspection` | | | | |
| `__init__.py` | 0 | 0 | 100% | |
| `coefficients.py` | 181 | 0 | 100% | |
| `impurity_decrease.py` | 103 | 2 | 98% | 423, 467 |
| `permutation_importance.py` | 196 | 1 | 99% | 583 |
| `utils.py` | 32 | 0 | 100% | |
| `skore/src/skore/_sklearn/_plot/metrics` | | | | |
| `__init__.py` | 6 | 0 | 100% | |
| `confusion_matrix.py` | 165 | 0 | 100% | |
| `metrics_summary_display.py` | 100 | 0 | 100% | |
| `precision_recall_curve.py` | 105 | 0 | 100% | |
| `prediction_error.py` | 166 | 0 | 100% | |
| `roc_curve.py` | 110 | 0 | 100% | |
| `skore/src/skore/_sklearn/train_test_split` | | | | |
| `__init__.py` | 2 | 0 | 100% | |
| `train_test_split.py` | 71 | 0 | 100% | |
| `skore/src/skore/_sklearn/train_test_split/warning` | | | | |
| `__init__.py` | 8 | 0 | 100% | |
| `high_class_imbalance_too_few_examples_warning.py` | 19 | 1 | 94% | 83 |
| `high_class_imbalance_warning.py` | 20 | 0 | 100% | |
| `random_state_unset_warning.py` | 10 | 0 | 100% | |
| `shuffle_true_warning.py` | 9 | 0 | 100% | |
| `stratify_is_set_warning.py` | 10 | 0 | 100% | |
| `time_based_column_warning.py` | 21 | 0 | 100% | |
| `train_test_split_warning.py` | 3 | 0 | 100% | |
| `skore/src/skore/_utils` | | | | |
| `__init__.py` | 6 | 2 | 66% | 8, 13 |
| `_accessor.py` | 106 | 7 | 93% | 36, 92–94, 164, 218, 238 |
| `_cache.py` | 37 | 0 | 100% | |
| `_cache_key.py` | 35 | 5 | 85% | 22, 24, 51, 59, 68 |
| `_dataframe.py` | 37 | 1 | 97% | 56 |
| `_environment.py` | 32 | 1 | 96% | 44 |
| `_fixes.py` | 8 | 0 | 100% | |
| `_index.py` | 5 | 0 | 100% | |
| `_jupyter.py` | 8 | 2 | 75% | 13–14 |
| `_logger.py` | 22 | 4 | 81% | 15–17, 19 |
| `_measure_time.py` | 10 | 0 | 100% | |
| `_parallel.py` | 17 | 0 | 100% | |
| `_patch.py` | 21 | 12 | 42% | 30, 35–39, 42–43, 46–47, 58, 60 |
| `_progress_bar.py` | 41 | 4 | 90% | 55–56, 66–67 |
| `_show_versions.py` | 38 | 0 | 100% | |
| `_testing.py` | 112 | 11 | 90% | 23, 32, 160, 169, 180–185, 187 |
| `skore/src/skore/_utils/repr` | | | | |
| `__init__.py` | 2 | 0 | 100% | |
| `base.py` | 54 | 0 | 100% | |
| `data.py` | 162 | 0 | 100% | |
| `html_repr.py` | 40 | 0 | 100% | |
| `rich_repr.py` | 81 | 0 | 100% | |
| **TOTAL** | 4324 | 113 | 97% | |

| Tests | Skipped | Failures | Errors | Time |
|---|---|---|---|---|
| 2046 | 5 💤 | 0 ❌ | 0 🔥 | 8m 16s ⏱️ |

@cakedev0
Contributor Author

Tests for 8e22d89 pass, but when pos_label=None (in binary classification) you get really bad plots 😅

image

@glemaitre glemaitre changed the title from refactor(skore): Make pos_ label more consistent to refactor(skore): Make pos_label more consistent Mar 25, 2026
@cakedev0
Contributor Author

Looking good now (fix: e130809), but I had to rewrite mannnny tests 😅 (9ccb49d)

image

@cakedev0 cakedev0 marked this pull request as ready for review March 26, 2026 09:16
Collaborator

@GaetandeCast GaetandeCast left a comment


Looks great, thanks @cakedev0! One suggestion to fix #2671. Also I did not follow the whole discussion but it looks like we don't infer the pos_label anymore?

@@ -215,6 +235,7 @@ def clear_cache(self) -> None:
def cache_predictions(
self,
response_methods: Literal["auto"] | str | list[str] = "auto",
Collaborator


One thing we could do to solve #2671 is to remove the response_methods here and cache all prediction methods in cache_prediction. This means we predict with either predict_proba or decision_function when available and deduce the predictions, or compute them with predict otherwise. This way we only predict once with the most informative function and can access any type of prediction in the cache later.
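The suggestion can be sketched as follows (a minimal illustration, under the assumption that predict agrees with the argmax of predict_proba, which holds for most scikit-learn classifiers):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

# Predict once with the most informative response method...
proba = clf.predict_proba(X)
# ...and deduce the hard predictions from it instead of calling predict again
deduced = clf.classes_[np.argmax(proba, axis=1)]

assert np.array_equal(deduced, clf.predict(X))
```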

Contributor Author


Yes, I 100% love this idea. This is a very valuable optimization I think (basically 2x/3x speed-up for predictions-dominated models).

Do you think we can decide to go with that? Or should we gather more opinions first?

Collaborator


I think we can go with it. I don't really see any drawbacks.

Contributor Author


Hum... with DummyClassifier(strategy="uniform"), the argmax of predict_proba is not predict (predictions are random, predict_proba is 1/n_classes everywhere); this makes a bunch of tests break.

More realistic examples are:

  • SVC(probability=True): but for this one, we can use the decision function
  • FixedThresholdClassifier and TunedThresholdClassifierCV (those probably break assumptions made in other places too).
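The DummyClassifier case above can be reproduced in a few lines (a quick repro sketch of the edge case):

```python
import numpy as np
from sklearn.dummy import DummyClassifier

X = np.zeros((1000, 1))
y = np.array([0, 1] * 500)

clf = DummyClassifier(strategy="uniform", random_state=0).fit(X, y)

# predict_proba is 1/n_classes everywhere, so its argmax is constant...
assert np.allclose(clf.predict_proba(X), 0.5)
# ...while predict draws labels uniformly at random, so both classes appear
assert set(clf.predict(X)) == {0, 1}
```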

I keep running into such small but assumption/abstraction-breaking edge cases these days 😭

Anyway, what do we do?

  1. Ignore that and change the tests to avoid DummyClassifier(strategy="uniform"); but maybe having it in the tests is on purpose.
  2. Implement special treatment for DummyClassifier (and maybe FixedThresholdClassifier, TunedThresholdClassifierCV), and use decision_function first when available (this fixes the SVC(probability=True) case).
  3. Give up on this nice optimization (which also simplifies the code quite a lot, honestly).

I vote for 2 ^^ but I'd like to have @glemaitre's opinion on that.

Member


I think that we can go for the optimization. In short, scikit-learn has common tests checking this assumption, so most classifiers in scikit-learn would benefit from it. However, we cannot provide the same optimization for estimators outside of scikit-learn, because we are not aware of their internals. In an ideal world, maybe the system of tags could help in this direction, but we cannot bet that people implement it.

So in short, we can implement it in another PR where we will have a fallback for *ThresholdClassifier* and non-scikit-learn estimators.

Contributor Author


Great. I opened a draft PR with the refactor/optim: #2677.

In this PR, I went with option 1 for #2671 (i.e. record only the time for predict, and not for other response methods).

@cakedev0
Contributor Author

> we don't infer the pos_label anymore?

Indeed. Inferring it only for the cases {0, 1} and {-1, 1} and crashing otherwise is not a great behavior. Instead, we choose not to crash and to display/return metrics for both labels.



Development

Successfully merging this pull request may close these issues.

  • BUG: Check for decision function shape (OVR vs OVO)
  • Bug: predict time is not stable across calls

3 participants