---
jupyter:
  jupytext:
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
    jupytext_version: 1.17.3
  kernelspec:
    display_name: Python 3
    name: python3
---

<!-- #region id="uJHywE_oL3j2" -->
# Differentially private convolutional neural network on MNIST

[Open in Colab](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/differentially_private_sgd.ipynb)

A large portion of this code is forked from the differentially private SGD
example in the [JAX repo](
https://github.com/jax-ml/jax/blob/main/examples/differentially_private_sgd.py).

To run this colab locally you need to install the `dp-accounting`,
`tensorflow`, and `tensorflow-datasets` packages via `pip`, e.g.
`pip install dp-accounting tensorflow tensorflow-datasets`.

[Differentially Private Stochastic Gradient Descent](https://arxiv.org/abs/1607.00133) requires clipping the per-example parameter
gradients, which is non-trivial to implement efficiently for convolutional
neural networks. The JAX XLA compiler shines in this setting by optimizing the
minibatch-vectorized computation for convolutional architectures. Training
takes a few seconds per epoch on a commodity GPU.
<!-- #endregion -->
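
Below is a minimal, illustrative sketch of the clip-and-noise aggregation that DP-SGD performs on per-example gradients. The `optax.contrib.dpsgd` transformation used later in this notebook does this for us, so this cell is purely for intuition; `clip_and_noise` is a hypothetical helper, not part of optax.

```python
import jax
import jax.numpy as jnp


def clip_and_noise(per_example_grads, l2_norm_clip, noise_multiplier, rng):
  # Computes each example's global L2 norm across all gradient leaves.
  sq_norms = sum(
      jnp.sum(g**2, axis=tuple(range(1, g.ndim)))
      for g in jax.tree.leaves(per_example_grads)
  )
  # Scaling factor that clips each per-example gradient to l2_norm_clip.
  scale = jnp.minimum(1.0, l2_norm_clip / jnp.sqrt(sq_norms))

  def aggregate(g, key):
    clipped_sum = jnp.sum(g * scale.reshape((-1,) + (1,) * (g.ndim - 1)), axis=0)
    # Gaussian noise with standard deviation noise_multiplier * l2_norm_clip.
    noise = noise_multiplier * l2_norm_clip * jax.random.normal(key, g.shape[1:])
    return (clipped_sum + noise) / g.shape[0]

  leaves, treedef = jax.tree.flatten(per_example_grads)
  keys = jax.random.split(rng, len(leaves))
  return jax.tree.unflatten(treedef, [aggregate(g, k) for g, k in zip(leaves, keys)])
```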

```python id="VaYIiCnjL3j3"
import warnings

import dp_accounting
import jax
import jax.numpy as jnp
from jax.example_libraries import stax
import matplotlib.pyplot as plt
import optax
from optax import contrib
from optax import losses
import tensorflow as tf
import tensorflow_datasets as tfds

# Shows on which platform JAX is running.
print("JAX running on", jax.devices()[0].platform.upper())
```

<!-- #region id="t7Dn8L_Uw0Yb" -->
This table contains hyperparameters and the corresponding expected test accuracy.

| DPSGD | LEARNING_RATE | NOISE_MULTIPLIER | L2_NORM_CLIP | BATCH_SIZE | NUM_EPOCHS | DELTA | FINAL TEST ACCURACY |
| ----- | ------------- | ---------------- | ------------ | ---------- | ---------- | ----- | ------------------- |
| False | 0.1           | NA               | NA           | 256        | 20         | NA    | ~99%                |
| True  | 0.25          | 1.3              | 1.5          | 256        | 15         | 1e-5  | ~95%                |
| True  | 0.15          | 1.1              | 1.0          | 256        | 60         | 1e-5  | ~96.6%              |
| True  | 0.25          | 0.7              | 1.5          | 256        | 45         | 1e-5  | ~97%                |
<!-- #endregion -->

```python id="jve2h810L3j3"
# Whether to use DP-SGD or vanilla SGD:
DPSGD = True
# Learning rate for the optimizer:
LEARNING_RATE = 0.25
# Noise multiplier for DP-SGD optimizer:
NOISE_MULTIPLIER = 1.3
# L2 norm clip:
L2_NORM_CLIP = 1.5
# Number of samples in each batch:
BATCH_SIZE = 256
# Number of epochs:
NUM_EPOCHS = 15
# Target delta of the (epsilon, delta)-DP guarantee:
DELTA = 1e-5
```

<!-- #region id="iLGeV4y4DBkL" -->
MNIST is composed of 28x28 grayscale images with a single channel. We'll now load the dataset using `tensorflow_datasets`, normalize pixel values to [0, 1], and display a few samples.
<!-- #endregion -->

```python id="zynvtk4wDBkL"
(train_loader, test_loader), info = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True, with_info=True
)

# Scales pixel values from [0, 255] to [0, 1].
normalize = lambda image, label: (tf.cast(image, tf.float32) / 255., label)
train_loader = train_loader.map(normalize)
test_loader = test_loader.map(normalize)

train_loader_batched = train_loader.shuffle(
    buffer_size=10_000, reshuffle_each_iteration=True
).batch(BATCH_SIZE, drop_remainder=True)

NUM_TRAIN_EXAMPLES = info.splits["train"].num_examples
NUM_TEST_EXAMPLES = info.splits["test"].num_examples
test_batch = next(
    test_loader.batch(NUM_TEST_EXAMPLES, drop_remainder=True).as_numpy_iterator()
)
```
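
As a quick sanity check, the cell below displays a few training samples with `matplotlib` (illustrative only, not needed for training):

```python
# Shows the first few training images with their labels.
_, axes = plt.subplots(ncols=5, figsize=(10, 2))
for ax, (image, label) in zip(axes, train_loader.take(5).as_numpy_iterator()):
  ax.imshow(image.squeeze(), cmap="gray")
  ax.set_title(int(label))
  ax.axis("off")
plt.show()
```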

```python id="o6In7oQ-0EhG"
init_random_params, predict = stax.serial(
    stax.Conv(16, (8, 8), padding="SAME", strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Conv(32, (4, 4), padding="VALID", strides=(2, 2)),
    stax.Relu,
    stax.MaxPool((2, 2), (1, 1)),
    stax.Flatten,
    stax.Dense(32),
    stax.Relu,
    stax.Dense(10),
)
```
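
To check that the architecture is wired up as intended, we can initialize throwaway parameters and confirm that a dummy MNIST batch maps to 10 logits per example (a quick sanity check, not part of the training pipeline):

```python
# Initializes throwaway parameters and checks the output shape on dummy inputs.
out_shape, tmp_params = init_random_params(jax.random.PRNGKey(0), (-1, 28, 28, 1))
dummy_logits = predict(tmp_params, jnp.zeros((2, 28, 28, 1)))
print(out_shape, dummy_logits.shape)  # Expects logits of shape (2, 10).
```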

<!-- #region id="j2OUgc6J0Jsl" -->
This function uses the `dp_accounting` RDP accountant to compute the privacy parameter epsilon spent after a given number of steps, at the target `DELTA`.
<!-- #endregion -->

```python id="43177TofzuOa"
def compute_epsilon(steps):
  if NUM_TRAIN_EXAMPLES * DELTA > 1.:
    warnings.warn("Your delta might be too high.")
  # Poisson subsampling ratio of one step over the training set.
  q = BATCH_SIZE / float(NUM_TRAIN_EXAMPLES)
  orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
  accountant = dp_accounting.rdp.RdpAccountant(orders)
  accountant.compose(dp_accounting.PoissonSampledDpEvent(
      q, dp_accounting.GaussianDpEvent(NOISE_MULTIPLIER)), steps)
  return accountant.get_epsilon(DELTA)
```
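
For example, we can query the accountant for the budget spent after one epoch of training with the hyperparameters above (a usage sketch; the exact value depends on the accountant):

```python
# Privacy budget spent after a single epoch.
steps_per_epoch = NUM_TRAIN_EXAMPLES // BATCH_SIZE
print(f"Epsilon after one epoch: {compute_epsilon(steps_per_epoch):.2f}")
```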

```python id="W9mPtPvB0D3X"
@jax.jit
def loss_fn(params, batch):
  images, labels = batch
  logits = predict(params, images)
  return losses.softmax_cross_entropy_with_integer_labels(logits, labels).mean(), logits


@jax.jit
def test_step(params, batch):
  images, labels = batch
  logits = predict(params, images)
  loss = losses.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
  accuracy = (logits.argmax(1) == labels).mean()
  return loss, accuracy * 100
```

```python id="vOet-_860ysL"
if DPSGD:
  tx = contrib.dpsgd(
      learning_rate=LEARNING_RATE, l2_norm_clip=L2_NORM_CLIP,
      noise_multiplier=NOISE_MULTIPLIER, seed=1337)
else:
  tx = optax.sgd(learning_rate=LEARNING_RATE)

_, params = init_random_params(jax.random.PRNGKey(1337), (-1, 28, 28, 1))
opt_state = tx.init(params)
```

```python id="b-NmP7g01EdA"
@jax.jit
def train_step(params, opt_state, batch):
  grad_fn = jax.grad(loss_fn, has_aux=True)
  if DPSGD:
    # Inserts a dimension in axis 1 to use jax.vmap over the batch.
    batch = jax.tree.map(lambda x: x[:, None], batch)
    # Uses jax.vmap across the batch to extract per-example gradients.
    grad_fn = jax.vmap(grad_fn, in_axes=(None, 0))

  grads, _ = grad_fn(params, batch)
  updates, new_opt_state = tx.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state
```
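
The axis-insertion trick is worth unpacking: `optax.contrib.dpsgd` expects gradients with a leading per-example axis, so each vmapped `loss_fn` call sees a "batch" containing a single example. A small shape check (illustrative only):

```python
# Per-example gradients carry a leading axis of size BATCH_SIZE.
example_batch = next(train_loader_batched.as_numpy_iterator())
per_example_batch = jax.tree.map(lambda x: x[:, None], example_batch)
per_example_grads, _ = jax.vmap(jax.grad(loss_fn, has_aux=True), in_axes=(None, 0))(
    params, per_example_batch
)
print(jax.tree.leaves(per_example_grads)[0].shape)  # e.g. (256, 8, 8, 1, 16)
```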

```python id="QMl9dnbJ1OtQ"
accuracy, loss, epsilon = [], [], []

for epoch in range(NUM_EPOCHS):
  for batch in train_loader_batched.as_numpy_iterator():
    params, opt_state = train_step(params, opt_state, batch)

  # Evaluates test accuracy.
  test_loss, test_acc = test_step(params, test_batch)
  accuracy.append(test_acc)
  loss.append(test_loss)
  print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, test accuracy: {test_acc:.2f}")

  # Tracks the privacy budget spent so far.
  if DPSGD:
    steps = (1 + epoch) * (NUM_TRAIN_EXAMPLES // BATCH_SIZE)
    eps = compute_epsilon(steps)
    epsilon.append(eps)
```

```python id="9nsV-9_b2qca"
if DPSGD:
  _, axs = plt.subplots(ncols=3, figsize=(9, 3))
else:
  _, axs = plt.subplots(ncols=2, figsize=(6, 3))

axs[0].plot(accuracy)
axs[0].set_title("Test accuracy")
axs[1].plot(loss)
axs[1].set_title("Test loss")

if DPSGD:
  axs[2].plot(epsilon)
  axs[2].set_title("Epsilon")

plt.tight_layout()
```

```python id="1ubOEWod3OPj"
print(f"Final accuracy: {accuracy[-1]}")
```
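
For DP-SGD runs it is also useful to report the final privacy budget alongside the accuracy (a small addition to the original example):

```python
if DPSGD:
  print(f"Final epsilon at delta={DELTA}: {epsilon[-1]:.2f}")
```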