Commit 133581c
Adding pre-commit action to build markdown files derived from example jupyter notebooks.
1 parent 4a53ba3 commit 133581c

26 files changed: +6990 −174 lines

- `.github/notebook-reviews/docs/getting_started.md`: 454 additions, 0 deletions (large diff not rendered by default)
- `.github/notebook-reviews/examples/adversarial_training.md`: 474 additions, 0 deletions (large diff not rendered by default)
- `.github/notebook-reviews/examples/cifar10_resnet.md`: 601 additions, 0 deletions (large diff not rendered by default)
Lines changed: 217 additions & 0 deletions
---
jupyter:
  jupytext:
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
    jupytext_version: 1.17.3
  kernelspec:
    display_name: Python 3
    name: python3
---

<!-- #region id="uJHywE_oL3j2" -->
# Differentially private convolutional neural network on MNIST

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/differentially_private_sgd.ipynb)

A large portion of this code is forked from the differentially private SGD example in the [JAX repo](https://github.com/jax-ml/jax/blob/main/examples/differentially_private_sgd.py).

To run the colab locally you need to install the `dp-accounting`, `tensorflow`, and `tensorflow-datasets` packages via `pip`.

[Differentially Private Stochastic Gradient Descent](https://arxiv.org/abs/1607.00133) requires clipping the per-example parameter gradients, which is non-trivial to implement efficiently for convolutional neural networks. The JAX XLA compiler shines in this setting by optimizing the minibatch-vectorized computation for convolutional architectures. Training takes a few seconds per epoch on a commodity GPU.
<!-- #endregion -->
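To make the clipping step concrete, here is a minimal standalone sketch of per-example gradient clipping with `jax.vmap`; the toy loss and shapes are illustrative, not the notebook's model:

```python
import jax
import jax.numpy as jnp

# Toy loss over a single example (illustrative only).
def toy_loss(w, x):
  return jnp.sum((x @ w) ** 2)

w = jnp.ones((3, 2))
batch = jnp.arange(12.0).reshape(4, 3)  # 4 examples with 3 features each

# vmap over the batch axis of x (but not over w) yields one gradient per example.
per_example_grads = jax.vmap(jax.grad(toy_loss), in_axes=(None, 0))(w, batch)
print(per_example_grads.shape)  # (4, 3, 2)

# Clip each per-example gradient to a maximum L2 norm, then average.
clip_norm = 1.0
norms = jnp.sqrt(jnp.sum(per_example_grads**2, axis=(1, 2), keepdims=True))
clipped = per_example_grads * jnp.minimum(1.0, clip_norm / norms)
mean_grad = clipped.mean(axis=0)
```

Vectorizing `jax.grad` this way is what lets XLA fuse the per-example gradient computation into a single efficient minibatch kernel.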
```python id="VaYIiCnjL3j3"
import warnings

import dp_accounting
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from optax import contrib
from optax import losses
from jax.example_libraries import stax
import tensorflow as tf
import tensorflow_datasets as tfds

# Shows on which platform JAX is running.
print("JAX running on", jax.devices()[0].platform.upper())
```
<!-- #region id="t7Dn8L_Uw0Yb" -->
This table contains hyperparameters and the corresponding expected final test accuracy.

| DPSGD | LEARNING_RATE | NOISE_MULTIPLIER | L2_NORM_CLIP | BATCH_SIZE | NUM_EPOCHS | DELTA | FINAL TEST ACCURACY |
| ----- | ------------- | ---------------- | ------------ | ---------- | ---------- | ----- | ------------------- |
| False | 0.1           | NA               | NA           | 256        | 20         | NA    | ~99%                |
| True  | 0.25          | 1.3              | 1.5          | 256        | 15         | 1e-5  | ~95%                |
| True  | 0.15          | 1.1              | 1.0          | 256        | 60         | 1e-5  | ~96.6%              |
| True  | 0.25          | 0.7              | 1.5          | 256        | 45         | 1e-5  | ~97%                |
<!-- #endregion -->
```python id="jve2h810L3j3"
# Whether to use DP-SGD or vanilla SGD:
DPSGD = True
# Learning rate for the optimizer:
LEARNING_RATE = 0.25
# Noise multiplier for the DP-SGD optimizer:
NOISE_MULTIPLIER = 1.3
# L2 norm clip:
L2_NORM_CLIP = 1.5
# Number of samples in each batch:
BATCH_SIZE = 256
# Number of epochs:
NUM_EPOCHS = 15
# Probability of information leakage:
DELTA = 1e-5
```
<!-- #region id="iLGeV4y4DBkL" -->
MNIST is composed of 28x28 grayscale images. We'll now load the dataset using `tensorflow_datasets` and scale the pixel values to [0, 1].
<!-- #endregion -->

```python id="zynvtk4wDBkL"
(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)

# Scales pixel values from [0, 255] to [0, 1].
min_max_rgb = lambda image, label: (tf.cast(image, tf.float32) / 255., label)
train_loader = train_loader.map(min_max_rgb)
test_loader = test_loader.map(min_max_rgb)

train_loader_batched = train_loader.shuffle(
    buffer_size=10_000, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)

# Number of training examples; used for the privacy accounting below.
NUM_EXAMPLES = info.splits["train"].num_examples
# The full test split is evaluated as a single batch.
NUM_TEST = info.splits["test"].num_examples
test_batch = next(
    test_loader.batch(NUM_TEST, drop_remainder=True).as_numpy_iterator()
)
```
```python id="o6In7oQ-0EhG"
init_random_params, predict = stax.serial(
    stax.Conv(16, (8, 8), padding="SAME", strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding="VALID", strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,
    stax.Dense(32),
    stax.Relu,
    stax.Dense(10),
)
```
<!-- #region id="j2OUgc6J0Jsl" -->
This function computes the privacy parameter epsilon for the given number of steps and the probability of information leakage `DELTA`.
<!-- #endregion -->

```python id="43177TofzuOa"
def compute_epsilon(steps):
  if NUM_EXAMPLES * DELTA > 1.:
    warnings.warn("Your delta might be too high.")
  q = BATCH_SIZE / float(NUM_EXAMPLES)
  orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
  accountant = dp_accounting.rdp.RdpAccountant(orders)
  accountant.compose(dp_accounting.PoissonSampledDpEvent(
      q, dp_accounting.GaussianDpEvent(NOISE_MULTIPLIER)), steps)
  return accountant.get_epsilon(DELTA)
```
```python id="W9mPtPvB0D3X"
@jax.jit
def loss_fn(params, batch):
  images, labels = batch
  logits = predict(params, images)
  return losses.softmax_cross_entropy_with_integer_labels(logits, labels).mean(), logits


@jax.jit
def test_step(params, batch):
  images, labels = batch
  logits = predict(params, images)
  loss = losses.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
  accuracy = (logits.argmax(1) == labels).mean()
  return loss, accuracy * 100
```
```python id="vOet-_860ysL"
if DPSGD:
  tx = contrib.dpsgd(
      learning_rate=LEARNING_RATE, l2_norm_clip=L2_NORM_CLIP,
      noise_multiplier=NOISE_MULTIPLIER, seed=1337)
else:
  tx = optax.sgd(learning_rate=LEARNING_RATE)

_, params = init_random_params(jax.random.PRNGKey(1337), (-1, 28, 28, 1))
opt_state = tx.init(params)
```
```python id="b-NmP7g01EdA"
@jax.jit
def train_step(params, opt_state, batch):
  grad_fn = jax.grad(loss_fn, has_aux=True)
  if DPSGD:
    # Inserts a dimension in axis 1 to use jax.vmap over the batch.
    batch = jax.tree.map(lambda x: x[:, None], batch)
    # Uses jax.vmap across the batch to extract per-example gradients.
    grad_fn = jax.vmap(grad_fn, in_axes=(None, 0))

  grads, _ = grad_fn(params, batch)
  updates, new_opt_state = tx.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state
```
```python id="QMl9dnbJ1OtQ"
accuracy, loss, epsilon = [], [], []

for epoch in range(NUM_EPOCHS):
  for batch in train_loader_batched.as_numpy_iterator():
    params, opt_state = train_step(params, opt_state, batch)

  # Evaluates test accuracy.
  test_loss, test_acc = test_step(params, test_batch)
  accuracy.append(test_acc)
  loss.append(test_loss)
  print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, test accuracy: {test_acc}")

  # Tracks the privacy budget spent so far.
  if DPSGD:
    steps = (1 + epoch) * NUM_EXAMPLES // BATCH_SIZE
    eps = compute_epsilon(steps)
    epsilon.append(eps)
```
```python id="9nsV-9_b2qca"
if DPSGD:
  _, axs = plt.subplots(ncols=3, figsize=(9, 3))
else:
  _, axs = plt.subplots(ncols=2, figsize=(6, 3))

axs[0].plot(accuracy)
axs[0].set_title("Test accuracy")
axs[1].plot(loss)
axs[1].set_title("Test loss")

if DPSGD:
  axs[2].plot(epsilon)
  axs[2].set_title("Epsilon")

plt.tight_layout()
```
```python id="1ubOEWod3OPj"
print(f"Final accuracy: {accuracy[-1]}")
```
Lines changed: 167 additions & 0 deletions
---
jupyter:
  jupytext:
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
    jupytext_version: 1.17.3
  kernelspec:
    display_name: venv
    language: python
    name: python3
---

# Using the Muon Optimizer in Optax

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/contrib/muon.ipynb)

This notebook demonstrates how to use the `optax.contrib.muon` optimizer. We'll cover three main use cases:

1. **Default Muon:** Automatically applying Muon to 2D matrices and AdamW to all other parameters.
2. **Explicit selection:** Using `muon_weight_dimension_numbers` to select exactly which parameters are optimized by Muon.
3. **Muon with reshaping:** Using `muon_weight_dimension_numbers` to apply Muon to higher-rank tensors by specifying how they should be reshaped into 2D matrices.

```python
from pprint import pprint

import jax
import jax.numpy as jnp
from jax import random

import optax
```
```python
# Create a sample PyTree of parameters with different dimensions.
keys = iter(random.split(random.key(0), 1024))
params = {
    "layer1": {
        "w": jax.random.normal(next(keys), (128, 64)),  # 2D matrix
        "b": jax.random.normal(next(keys), (64,)),  # 1D vector
    },
    "layer2": {
        "w": jax.random.normal(next(keys), (64, 32)),  # 2D matrix
    },
    "layer3_conv": {
        "w": jax.random.normal(next(keys), (4, 3, 3, 16))  # 4D tensor
    },
}


# A simple loss function: sum of squares of the parameters.
# Its gradient is twice the parameters.
@jax.jit
def loss_fn(p):
  return sum(jnp.sum(x**2) for x in jax.tree.leaves(p))
```
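As a quick standalone sanity check (not part of the original notebook), `jax.grad` of this sum-of-squares loss returns twice each parameter leaf:

```python
import jax
import jax.numpy as jnp

# Tiny stand-in parameter tree (names are illustrative).
p = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}

def sum_of_squares(tree):
  return sum(jnp.sum(x**2) for x in jax.tree.leaves(tree))

g = jax.grad(sum_of_squares)(p)
# d/dx of x**2 is 2x, so every gradient leaf is 2 * the parameter leaf.
assert bool(jnp.allclose(g["w"], 2 * p["w"]))
assert bool(jnp.allclose(g["b"], 2 * p["b"]))
```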
```python
def print_state(state):
  print("State variables using the muon transform ---------------------------")
  pprint({
      "".join(map(str, k)): "MUON"
      for k, v in jax.tree.flatten_with_path(state.inner_states["muon"])[0]
      if v.ndim > 0 and not str(k[-1]).endswith("ns_coeffs")
  })
  print()
  print("State variables using the adam transform ---------------------------")
  pprint({
      "".join(map(str, k)): "ADAM"
      for k, v in jax.tree.flatten_with_path(state.inner_states["adam"])[0]
      if v.ndim > 0 and not str(k[-1]).endswith("ns_coeffs")
  })
```
## 1. Default Muon Configuration

By default, `muon` partitions parameters based on their dimensionality. Parameters with `ndim == 2` (matrices) are optimized with Muon, while all others are handled by a standard AdamW optimizer.

```python
# Use muon with default partitioning (ndim == 2 for muon).
opt = optax.contrib.muon(learning_rate=1e-3)
opt_state = opt.init(params)

print_state(opt_state)
```
## 2. Using `muon_weight_dimension_numbers` for Explicit Selection and Higher-Rank Tensors

The core Muon algorithm (specifically, the Newton-Schulz iteration) operates on 2D matrices. To apply it to tensors of rank > 2, you must provide a `MuonDimensionNumbers` that tells the optimizer how to reshape the tensor into a 2D matrix of shape `(reduction_dim, output_dim)`.

- `reduction_axes`: A tuple of axis indices that will be flattened into the first dimension of the matrix.
- `output_axes`: A tuple of axis indices that will be flattened into the second dimension.

Any remaining axes are treated as batch dimensions, and the operation is applied independently across them.

You can override the default behavior using `muon_weight_dimension_numbers`. This is a PyTree with the same (or a prefix) structure as your parameters, containing `MuonDimensionNumbers` named tuples. If a leaf is a `MuonDimensionNumbers` tuple, the corresponding parameter is handled by Muon; if it is `None`, the parameter is handled by AdamW.

Let's apply Muon *only* to `'layer1'`'s weights and use AdamW for everything else, including the other 2D matrix in `'layer2'`.

```python
print("optax.contrib.MuonDimensionNumbers docstring:\n")
print(optax.contrib.MuonDimensionNumbers.__doc__)
```
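The reshape described above can be sketched directly in `jax.numpy` (a simplified illustration of the idea; optax's internal handling may differ):

```python
import jax.numpy as jnp

# A 4D kernel with the same shape as layer3_conv's weights.
kernel = jnp.zeros((4, 3, 3, 16))
reduction_axes, output_axes = (0, 1, 2), (3,)

# Move reduction axes first and output axes last, then flatten each group,
# yielding the 2D matrix that the Newton-Schulz iteration would operate on.
transposed = jnp.transpose(kernel, reduction_axes + output_axes)
matrix = transposed.reshape(4 * 3 * 3, 16)
print(matrix.shape)  # (36, 16)
```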
```python
# Dimension numbers to apply Muon ONLY to layer1's weights.
weight_dim_nums = {
    "layer1": {
        # The default for a 2D matrix is `optax.contrib.MuonDimensionNumbers(0, 1)`.
        "w": optax.contrib.MuonDimensionNumbers(),
        "b": None,
    },
    "layer2": {
        "w": None,
    },
    "layer3_conv": {
        "w": None,
    },
}

opt = optax.contrib.muon(
    learning_rate=1e-3, muon_weight_dimension_numbers=weight_dim_nums
)
opt_state = opt.init(params)
print_state(opt_state)
```
Let's apply Muon to our 4D convolutional weight tensor from `layer3_conv`.

```python
# We want to apply Muon to the 4D convolutional kernel in 'layer3_conv'.
# Its shape is (4, 3, 3, 16). We treat the first three axes (4*3*3 = 36)
# as the 'reduction' dimension and the last axis (16) as the 'output' dimension.

# Define the corresponding MuonDimensionNumbers for the selected tensors.
# The structure must match the parameters. Use None for non-Muon params.
weight_dim_nums = {
    "layer1": {"w": optax.contrib.MuonDimensionNumbers((0,), (1,)), "b": None},
    "layer2": {"w": None},
    "layer3_conv": {
        "w": optax.contrib.MuonDimensionNumbers(
            reduction_axes=(0, 1, 2), output_axes=(3,)
        ),
    },
}

opt = optax.contrib.muon(
    learning_rate=1e-3, muon_weight_dimension_numbers=weight_dim_nums
)
opt_state = opt.init(params)

print_state(opt_state)
```
