Machine Unlearning: Stop the Evil AI from taking over your personal data
AI is scaling at an incredibly rapid pace, and there’s a growing need not only to train models effectively but also to untrain them when necessary, that is, to make a trained model forget specific data. This is where Machine Unlearning comes into play. In this blog post, we’ll explore Machine Unlearning, focusing on one unlearning algorithm and on evaluating it with a Membership Inference Attack (MIA).
🎯 Unlearning Algorithm
The essence of machine unlearning lies in developing algorithms that let a trained model selectively forget specific training data while retaining the knowledge we want to preserve. Let’s dive into an example unlearning algorithm.
The Retain and Forget Split
The unlearning algorithm begins by splitting the original training dataset into two subsets: the retain set and the forget set. The retain set contains the data we want to keep and is typically much larger than the forget set, which contains the data the model should forget. In this example, we use a 90% retain set and a 10% forget set:
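```python
import keras
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE  # train_ds (batched) and BATCH_SIZE are assumed to be defined earlier
forget_set, retain_set = keras.utils.split_dataset(train_ds.unbatch(), left_size=0.1)  # 10% forget / 90% retain
forget_ds = forget_set.batch(BATCH_SIZE).prefetch(AUTOTUNE)
retain_ds = retain_set.batch(BATCH_SIZE).prefetch(AUTOTUNE)
```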
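To make the semantics of the split concrete without TensorFlow, here is a minimal framework-free sketch. The helper `split_forget_retain` is hypothetical (it is not part of Keras): it takes a list of examples, optionally shuffles it with a seed, and puts the first `forget_fraction` of the examples in the forget set. This roughly mirrors the `keras.utils.split_dataset(..., left_size=0.1)` call above, whose `shuffle` argument defaults to `False`, so there the forget set is simply the first 10% of examples in dataset order.

```python
import random
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_forget_retain(
    examples: Sequence[T], forget_fraction: float, seed: Optional[int] = None
) -> Tuple[List[T], List[T]]:
    """Hypothetical helper: split examples into (forget, retain) lists.

    If seed is None, the original order is kept (like shuffle=False);
    otherwise the examples are shuffled reproducibly with that seed first.
    """
    if not 0.0 <= forget_fraction <= 1.0:
        raise ValueError("forget_fraction must be in [0, 1]")
    items = list(examples)
    if seed is not None:
        random.Random(seed).shuffle(items)
    n_forget = round(forget_fraction * len(items))
    # The two splits are disjoint and together cover every example.
    return items[:n_forget], items[n_forget:]


forget, retain = split_forget_retain(list(range(1000)), forget_fraction=0.1)
print(len(forget), len(retain))  # 100 900
```

Shuffling with a fixed seed gives a random but reproducible forget set; keeping the order means the forget set depends entirely on how the training data happened to be ordered.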
Unlearning by Fine-Tuning
One simple unlearning algorithm is unlearning by fine-tuning. It starts from the pre-trained model, which was trained on the full training set, and continues training it for a few epochs on the retain set only, so the forget set no longer contributes any gradient updates. Here’s the unlearning function:
```python
def unlearning(net, retain, forget, validation):
    """Unlearning by fine-tuning.

    Args:
      net : keras.Model.
        Pre-trained model to use as a base for unlearning.
      retain : tf.data.Dataset.
        Dataset loader for the retain set.
      forget : tf.data.Dataset.
        Dataset loader for the forget set.
      validation : tf.data.Dataset.
        Dataset loader for the validation set.

    Returns:
      net : Updated model.
    """
    # ... Setup and compile the model ...
```
The algorithm fine-tunes the model on the retain set, and after a few epochs, it returns an updated model (model_ft) that should…
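The mechanics of unlearning by fine-tuning do not depend on Keras. The following is a minimal, dependency-free sketch under loudly stated assumptions: a toy logistic-regression model stands in for the Keras network, plain full-batch gradient descent stands in for the compiled optimizer, and the data is synthetic. All names (`train`, `unlearn_by_finetuning`, `mean_log_loss`, and so on) are hypothetical illustrations, not this post's implementation. What it shows is the core idea: the unlearned model starts from the original weights and is then updated only with gradients computed on the retain set.

```python
import math
import random
from typing import List, Sequence, Tuple

Example = Tuple[List[float], int]  # (features, binary label)


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(w: List[float], x: List[float]) -> float:
    # w[0] is the bias; w[1:] are the feature weights.
    return sigmoid(w[0] + sum(wi * xi for wi, xi in zip(w[1:], x)))


def mean_log_loss(w: List[float], data: Sequence[Example]) -> float:
    eps = 1e-12  # avoids log(0)
    total = 0.0
    for x, y in data:
        p = predict(w, x)
        total -= y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps)
    return total / len(data)


def train(w: List[float], data: Sequence[Example], epochs: int, lr: float) -> List[float]:
    # Full-batch gradient descent on the mean log loss; returns new weights.
    w = list(w)  # work on a copy so the caller's weights are untouched
    n = len(data)
    for _ in range(epochs):
        grad = [0.0] * len(w)
        for x, y in data:
            err = predict(w, x) - y  # derivative of log loss w.r.t. the logit
            grad[0] += err
            for j, xj in enumerate(x, start=1):
                grad[j] += err * xj
        w = [wi - lr * g / n for wi, g in zip(w, grad)]
    return w


def unlearn_by_finetuning(
    w_original: List[float], retain: Sequence[Example], epochs: int = 5, lr: float = 0.1
) -> List[float]:
    # Hypothetical: start from the original model and keep training on the
    # retain set only; the forget set contributes no gradients at all.
    return train(w_original, retain, epochs, lr)


# Synthetic 1-D data: the label is 1 when a noisy copy of x is positive.
rng = random.Random(0)
data: List[Example] = []
for _ in range(200):
    x = rng.uniform(-2.0, 2.0)
    data.append(([x], 1 if x + rng.gauss(0.0, 0.5) > 0 else 0))
forget, retain = data[:20], data[20:]  # 10% forget / 90% retain

w_orig = train([0.0, 0.0], data, epochs=50, lr=0.1)  # "original" model saw all data
w_unlearned = unlearn_by_finetuning(w_orig, retain)
print("retain loss:", mean_log_loss(w_orig, retain), "->", mean_log_loss(w_unlearned, retain))
print("forget loss:", mean_log_loss(w_orig, forget), "->", mean_log_loss(w_unlearned, forget))
```

In this convex toy with a small learning rate, the retain-set loss can only go down. The forget-set loss may or may not rise, and the procedure itself gives no guarantee that the forget set has actually been forgotten, which is why the unlearned model still has to be evaluated, in this post with a Membership Inference Attack.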