Skip to content

Commit acd36ba

Browse files
committed
Fix MSR notebook after upgrade to torch 2.8
1 parent 5e8e723 commit acd36ba

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

notebooks/msr_banzhaf_digits.ipynb

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
"\n",
1313
"Additionally, we compare two sampling techniques: the standard permutation-based Monte Carlo sampling, and the so-called MSR (Maximum Sample Reuse) principle.\n",
1414
"\n",
15-
"In order to highlight the strengths of Data-Banzhaf, we require a stochastic model. For this reason, we use a CNN to classify handwritten digits from the [scikit-learn toy datasets](https://scikit-learn.org/stable/datasets/toy_dataset.html#optical-recognition-of-handwritten-digits-dataset)."
15+
"In order to highlight the strengths of Data-Banzhaf, we require a stochastic model. For this reason, we use a CNN to classify handwritten MNIST digits from the [scikit-learn toy datasets](https://scikit-learn.org/stable/datasets/toy_dataset.html#optical-recognition-of-handwritten-digits-dataset).\n",
16+
"\n",
17+
"To showcase the use of pytorch with valuation methods, the network is a [skorch](https://github.com/skorch-dev/skorch) model."
1618
]
1719
},
1820
{
@@ -128,7 +130,9 @@
128130
"source": [
129131
"## Creating the utility and computing Banzhaf semi-values\n",
130132
"\n",
131-
"Now we can calculate the contribution of each training sample to the model performance. We use a simple CNN written in torch and wrap it into a [skorch.classifier.NeuralNetClassifier][]. Note that any model that implements the protocol [SupervisedModel][pydvl.utils.types.SupervisedModel], which is just the standard sklearn interface of `fit()`,`predict()` and `score()` can be used to construct the utility, so it is possible to construct your own wrapper. Nevertheless, skorch conveniently implements the full sklearn interface, allowing e.g. the use of torch models in pipelines."
133+
"Now we can calculate the contribution of each training sample to the model performance. We use a simple CNN written in torch and wrap it into a [skorch.classifier.NeuralNetClassifier][]. Note that any model that implements the protocol [SupervisedModel][pydvl.utils.types.SupervisedModel], which is just the standard sklearn interface of `fit()`,`predict()` and `score()` can be used to construct the utility, so it is possible to construct your own wrapper. Nevertheless, skorch conveniently implements the full sklearn interface, allowing e.g. the use of torch models in pipelines.\n",
134+
"\n",
135+
"It's important to note the use `torch_load_kwargs={\"weights_only\": False}` in the model definition. This is necessary to ensure that pickling works correctly, which is required for parallel operation. It should also be possible to use `torch.serialization.add_safe_globals()`, following the suggestion in pytorch's documentation."
132136
]
133137
},
134138
{
@@ -166,6 +170,7 @@
166170
" optimizer=torch.optim.Adam,\n",
167171
" device=device,\n",
168172
" verbose=False,\n",
173+
" torch_load_kwargs={\"weights_only\": False},\n",
169174
")\n",
170175
"model.fit(*train.data());"
171176
]
@@ -532,7 +537,7 @@
532537
"cell_type": "markdown",
533538
"metadata": {},
534539
"source": [
535-
"### Maximum Sample Reuse Banzhaf\n",
540+
"## Maximum Sample Reuse Banzhaf\n",
536541
"\n",
537542
"Despite the previous results already being useful, we had to retrain the model a number of times and yet the variance of the value estimates was high. This has consequences for the stability of the top-k ranking of points, which decreases the applicability of the method. We will now use a different sampling method called Maximum Sample Reuse (MSR) which reuses every sample for updating the Banzhaf values. The method was introduced by the authors of Data-Banzhaf and is much more sample-efficient, as we will show."
538543
]
@@ -678,7 +683,7 @@
678683
"cell_type": "markdown",
679684
"metadata": {},
680685
"source": [
681-
"### Compare convergence speed of Banzhaf and MSR Banzhaf Values\n",
686+
"### Convergence speed of Banzhaf and MSR Banzhaf Values\n",
682687
"\n",
683688
"Conventional margin-based samplers produce require evaluating the utility twice to do one update of the value, and permutation samplers do instead $n+1$ evaluations for $n$ updates. Maximum Sample Reuse (MSR) updates instead all indices in every sample that the utility evaluates. We compare the convergence rates of these methods.\n",
684689
"\n",
@@ -857,7 +862,7 @@
857862
"cell_type": "markdown",
858863
"metadata": {},
859864
"source": [
860-
"### Similarity of the semi-values computed using different samplers"
865+
"## Similarity of the semi-values computed using different samplers"
861866
]
862867
},
863868
{

requirements-notebooks.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ dask==2024.8.0
33
distributed==2024.8.0
44
imblearn
55
pillow==10.4.0
6-
skorch>=1.1.0
6+
skorch>=1.2.0
77
torch>=2.8.0
88
torchvision>=0.23.0
99
transformers==4.44.2

0 commit comments

Comments
 (0)