Generative Adversarial Network
# Install the necessary dependencies
import os
import sys
!{sys.executable} -m pip install --quiet pandas scikit-learn numpy matplotlib jupyterlab_myst ipython tensorflow
43.21. Generative Adversarial Network#
43.21.1. What is a generative model? (1/3)#
Many machine learning systems look at some kind of complicated input (say, an image) and produce a simple output (a label like "cat").
By contrast, the goal of a generative model is something like the opposite: take a small piece of input (perhaps a few random numbers) and produce a complex output, such as an image of a realistic-looking face.
43.21.2. What is a generative model? (2/3)#
“Generative” describes a class of statistical models that contrasts with discriminative models.
Informally:
Generative models can generate new data instances.
Discriminative models discriminate between different kinds of data instances.
A generative model could generate new photos of animals that look like real animals, while a discriminative model could tell a dog from a cat.
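To make the contrast concrete, here is a toy sketch (an illustration added here, not part of the slides) using scikit-learn, with made-up two-dimensional "cat" and "dog" features: a discriminative classifier learns the boundary between the classes, while a generative mixture model learns the data density and can sample brand-new points from it.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
cats = rng.normal(loc=[0, 0], scale=1.0, size=(200, 2))  # made-up "cat" features
dogs = rng.normal(loc=[3, 3], scale=1.0, size=(200, 2))  # made-up "dog" features
X = np.vstack([cats, dogs])
y = np.array([0] * 200 + [1] * 200)

# Discriminative: models p(y | x), i.e. a decision boundary between classes
clf = LogisticRegression().fit(X, y)
print(clf.predict([[0.5, 0.5], [2.8, 3.1]]))  # classifies two query points

# Generative: models p(x), so it can draw new instances from the learned density
gm = GaussianMixture(n_components=2, random_state=0).fit(X)
new_samples, _ = gm.sample(3)
print(new_samples)  # three synthetic points that "look like" the training data
```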
43.21.3. What is a generative model? (3/3)#
A generative model can model a distribution by producing convincing “fake” data that looks like it’s drawn from that distribution.
A generative model for images might capture correlations like “things that look like boats are probably going to appear near things that look like water” and “eyes are unlikely to appear on foreheads.” These are very complicated distributions.
43.21.4. What is a generative adversarial network (GAN)?#
GANs are an especially effective type of generative model, introduced in 2014 by Ian Goodfellow, and they have been a subject of intense interest in the machine learning community.
43.21.5. How does a GAN work?#
A generative adversarial network (GAN) has two parts, trained against each other in a minimax game (formalized after this list):
The generator learns to generate plausible data. The generated instances become negative training examples for the discriminator.
The discriminator learns to distinguish the generator’s fake data from real data. The discriminator penalizes the generator for producing implausible results.
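Formally, in Goodfellow et al.'s original 2014 formulation, the discriminator $D$ maximizes, and the generator $G$ minimizes, the value function

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

where $p_{\text{data}}$ is the real-data distribution and $p_z$ is the noise prior the generator samples from.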
43.21.6. How does a GAN work?#
When training begins, the generator produces obviously fake data, and the discriminator quickly learns to tell that it’s fake:
43.21.7. How does a GAN work?#
As training progresses, the generator gets closer to producing output that can fool the discriminator:
43.21.8. How does a GAN work?#
Finally, if generator training goes well, the discriminator gets worse at telling the difference between real and fake. It starts to classify fake data as real, and its accuracy decreases.
43.21.9. How does a GAN work?#
Here’s a picture of the whole system, sketched as a data-flow diagram:
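```
random noise z ──▶ generator ──▶ fake samples ──┐
                                                ├──▶ discriminator ──▶ real or fake?
real samples (training set) ────────────────────┘
```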
43.21.10. How does a GAN work?#
Both the generator and the discriminator are neural networks. The generator output is connected directly to the discriminator input. Through backpropagation, the discriminator’s classification provides a signal that the generator uses to update its weights.
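As a minimal illustration of that signal path (using tiny stand-in models, not the ones built later in this notebook), one can backpropagate the discriminator's verdict into the generator's weights with `tf.GradientTape`:

```python
import tensorflow as tf

# Toy stand-ins: a one-layer generator and discriminator (assumed shapes)
toy_generator = tf.keras.Sequential([tf.keras.layers.Dense(784, activation="tanh")])
toy_discriminator = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
bce = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam()

noise = tf.random.normal([32, 100])
with tf.GradientTape() as tape:
    fake = toy_generator(noise)            # generator output ...
    verdict = toy_discriminator(fake)      # ... feeds straight into the discriminator
    # The generator wants the discriminator to say "real" (label 1)
    loss = bce(tf.ones_like(verdict), verdict)

# Backpropagate through the discriminator into the *generator's* weights only
grads = tape.gradient(loss, toy_generator.trainable_variables)
optimizer.apply_gradients(zip(grads, toy_generator.trainable_variables))
```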
43.21.11. Time to play!#
43.21.12. Let’s code a GAN!#
## Packages & Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Dropout, Input, ReLU
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
## Data Import
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print("X train shape: ", x_train.shape)
print("Y train shape: ", y_train.shape)
print("X test shape: ", x_test.shape)
print("Y test shape: ", y_test.shape)
# preview one raw training digit
plt.imshow(x_train[108], cmap="gray")
# rescale pixel values from [0, 255] to [-1, 1]
x_train = (x_train.astype(np.float32) - 127.5) / 127.5

# flatten each 28x28 image into a 784-dimensional vector
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1] * x_train.shape[2])
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1] * x_test.shape[2])
print("X train shape: ", x_train.shape)
print("X test shape: ", x_test.shape)
43.21.13. Create Model#
# create generator: maps a 100-dim noise vector to a 784-pixel image
def create_generator():
    generator = Sequential()
    generator.add(Dense(units=512, input_dim=100))
    generator.add(ReLU())
    generator.add(Dense(units=512))
    generator.add(ReLU())
    generator.add(Dense(units=1024))
    generator.add(ReLU())
    # tanh keeps outputs in [-1, 1], matching the rescaled training images
    generator.add(Dense(units=784, activation="tanh"))
    generator.compile(loss="binary_crossentropy",
                      optimizer=Adam(learning_rate=0.0001, beta_1=0.5))
    return generator
g = create_generator()
g.summary()
# create discriminator: maps a 784-pixel image to a real/fake probability
def create_discriminator():
    discriminator = Sequential()
    discriminator.add(Dense(units=1024, input_dim=784))
    discriminator.add(ReLU())
    discriminator.add(Dropout(0.4))
    discriminator.add(Dense(units=512))
    discriminator.add(ReLU())
    discriminator.add(Dropout(0.4))
    discriminator.add(Dense(units=256))
    discriminator.add(ReLU())
    # sigmoid outputs the probability that the input image is real
    discriminator.add(Dense(units=1, activation="sigmoid"))
    discriminator.compile(loss="binary_crossentropy",
                          optimizer=Adam(learning_rate=0.0001, beta_1=0.5))
    return discriminator
d = create_discriminator()
d.summary()
# stack generator and discriminator into the combined GAN
def create_gan(discriminator, generator):
    # freeze the discriminator so that training the stacked model
    # only updates the generator's weights
    discriminator.trainable = False
    gan_input = Input(shape=(100,))
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss="binary_crossentropy", optimizer="adam")
    return gan
gan = create_gan(d, g)
gan.summary()
epochs = 2  # should be large, e.g. 50
batch_size = 256
batch_count = x_train.shape[0] // batch_size  # mini-batches per epoch

for e in range(epochs):
    for _ in range(batch_count):
        # generated (fake) image batch
        noise = np.random.normal(0, 1, [batch_size, 100])
        generated_images = g.predict(noise, verbose=0)

        # real image batch, sampled at random from the training set
        image_batch = x_train[np.random.randint(low=0, high=x_train.shape[0], size=batch_size)]

        # train the discriminator: real images are labeled 1, fakes 0
        x = np.concatenate([image_batch, generated_images])
        y_dis = np.zeros(batch_size * 2)
        y_dis[:batch_size] = 1
        d.trainable = True
        d.train_on_batch(x, y_dis)

        # train the generator through the frozen discriminator:
        # label the fakes as 1 so the generator is rewarded for fooling it
        noise = np.random.normal(0, 1, [batch_size, 100])
        y_gen = np.ones(batch_size)
        d.trainable = False
        gan.train_on_batch(noise, y_gen)
    print("Epoch: ", e)
g.save("g.hdf5")
# sample 100 noise vectors and visualize one generated digit
noise = np.random.normal(loc=0, scale=1, size=[100, 100])
generated_images = g.predict(noise, verbose=0)
generated_images = generated_images.reshape(100, 28, 28)
plt.imshow(generated_images[66], cmap="gray", interpolation="nearest")
plt.axis("off")
plt.show()
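To inspect more than one sample at a time, a small plotting helper (an addition here, not part of the original notebook) can tile the first 25 generated digits:

```python
def plot_generated_grid(images, rows=5, cols=5):
    """Tile rows x cols generated digits in a single figure."""
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
    for ax, img in zip(axes.flat, images):
        ax.imshow(img, cmap="gray", interpolation="nearest")
        ax.axis("off")
    plt.tight_layout()
    plt.show()

plot_generated_grid(generated_images[:25])
```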
# load a pretrained generator and sample from it
gp = keras.models.load_model('g-pretrained.hdf5')
noise = np.random.normal(loc=0, scale=1, size=[100, 100])
generated_images = gp.predict(noise, verbose=0)
generated_images = generated_images.reshape(100, 28, 28)
plt.imshow(generated_images[66], cmap="gray", interpolation="nearest")
plt.axis("off")
plt.show()