42.120. Generative Adversarial Networks (GANs)#
42.120.1. Loading data#
import os
import requests
import zipfile

directory = "./tmp/all-dogs"
zip_url = "https://static-1300131294.cos.ap-shanghai.myqcloud.com/data/deep-learning/Gan/all-dogs.zip"

os.makedirs(directory, exist_ok=True)

response = requests.get(zip_url)
zip_filename = os.path.join(directory, "all-dogs.zip")
with open(zip_filename, "wb") as file:
    file.write(response.content)
print("ZIP File successfully downloaded")

with zipfile.ZipFile(zip_filename, "r") as zip_ref:
    zip_ref.extractall(directory)
print("ZIP File successfully unzipped")

os.remove(zip_filename)
print(os.listdir("./tmp"))
42.120.2. Importing the libraries#
from __future__ import print_function
import time

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable

from torchvision import datasets, transforms
import torchvision.utils as vutils
from torchvision.utils import save_image
42.120.3. Some dogs#
The Stanford Dogs dataset contains images of 120 breeds of dogs from around the world.
PATH = './tmp/all-dogs/dogs/'
images = os.listdir(PATH)
print(f'There are {len(images)} pictures of dogs.')

fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 10))
for indx, axis in enumerate(axes.flatten()):
    rnd_indx = np.random.randint(0, len(images))
    img = plt.imread(PATH + images[rnd_indx])
    imgplot = axis.imshow(img)
    axis.set_title(images[rnd_indx])
    axis.set_axis_off()
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
42.120.4. Image Preprocessing#
batch_size = 32
image_size = 64

random_transforms = [transforms.ColorJitter(), transforms.RandomRotation(degrees=20)]
transform = transforms.Compose([transforms.Resize(64),
                                transforms.CenterCrop(64),
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.RandomApply(random_transforms, p=0.2),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_data = datasets.ImageFolder('./tmp', transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True,
                                           batch_size=batch_size)

imgs, label = next(iter(train_loader))
imgs = imgs.numpy().transpose(0, 2, 3, 1)
for i in range(5):
    # Undo the Normalize step so imshow receives values in [0, 1] instead of [-1, 1]
    plt.imshow(imgs[i] * 0.5 + 0.5)
    plt.show()
42.120.5. Weights#
def weights_init(m):
    # DCGAN-style initialization: conv weights ~ N(0, 0.02), batchnorm weights ~ N(1, 0.02), biases 0
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
42.120.6. Generator#
class G(nn.Module):
    def __init__(self):
        super(G, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        output = self.main(input)
        return output

netG = G()
netG.apply(weights_init)
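As a quick sanity check (not part of the original notebook), we can feed a small batch of random noise through the freshly initialized generator and confirm that it maps 100-dimensional latent vectors to 3×64×64 images:

# Hedged sanity check: the generator should turn (N, 100, 1, 1) noise into (N, 3, 64, 64) images.
with torch.no_grad():
    test_noise = torch.randn(4, 100, 1, 1)
    test_out = netG(test_noise)
print(test_out.shape)  # expected: torch.Size([4, 3, 64, 64])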
42.120.7. Discriminator#
class D(nn.Module):
    def __init__(self):
        super(D, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(512, 1, 4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1)

netD = D()
netD.apply(weights_init)
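Chaining the two networks is another useful check (again, only a sketch added here, not from the original notebook): the discriminator should return one probability per generated image.

# Hedged sanity check: score a batch of fake images with the discriminator.
with torch.no_grad():
    fake_batch = netG(torch.randn(4, 100, 1, 1))
    scores = netD(fake_batch)
print(scores.shape)  # expected: torch.Size([4]), each value in (0, 1) from the Sigmoid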
42.120.8. Another setup#
class Generator(nn.Module):
    def __init__(self, nz=128, channels=3):
        super(Generator, self).__init__()
        self.nz = nz
        self.channels = channels

        def convlayer(n_input, n_output, k_size=4, stride=2, padding=0):
            block = [
                nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(n_output),
                nn.ReLU(inplace=True),
            ]
            return block

        self.model = nn.Sequential(
            *convlayer(self.nz, 1024, 4, 1, 0),
            *convlayer(1024, 512, 4, 2, 1),
            *convlayer(512, 256, 4, 2, 1),
            *convlayer(256, 128, 4, 2, 1),
            *convlayer(128, 64, 4, 2, 1),
            nn.ConvTranspose2d(64, self.channels, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, z):
        z = z.view(-1, self.nz, 1, 1)
        img = self.model(z)
        return img


class Discriminator(nn.Module):
    def __init__(self, channels=3):
        super(Discriminator, self).__init__()
        self.channels = channels

        def convlayer(n_input, n_output, k_size=4, stride=2, padding=0, bn=False):
            block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)]
            if bn:
                block.append(nn.BatchNorm2d(n_output))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        self.model = nn.Sequential(
            *convlayer(self.channels, 32, 4, 2, 1),
            *convlayer(32, 64, 4, 2, 1),
            *convlayer(64, 128, 4, 2, 1, bn=True),
            *convlayer(128, 256, 4, 2, 1, bn=True),
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
        )

    def forward(self, imgs):
        logits = self.model(imgs)
        out = torch.sigmoid(logits)
        return out.view(-1, 1)
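The same kind of shape check (a sketch, not from the original notebook) confirms that this second Generator maps an nz-dimensional latent vector to a 64×64 RGB image and that the matching Discriminator returns one probability per image:

# Hedged sanity check for the second setup.
with torch.no_grad():
    g, d = Generator(nz=128), Discriminator()
    z = torch.randn(4, 128)  # forward() reshapes this to (4, 128, 1, 1)
    imgs_out = g(z)
    probs = d(imgs_out)
print(imgs_out.shape, probs.shape)  # expected: torch.Size([4, 3, 64, 64]) torch.Size([4, 1])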
42.120.9. Training#
EPOCH = 0   # set to a positive number to run this basic loop; the tuned training in 42.120.10 is used instead
LR = 0.001

criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(0.5, 0.999))

os.makedirs("./results", exist_ok=True)

for epoch in range(EPOCH):
    for i, data in enumerate(train_loader, 0):
        # 1st Step: Updating the weights of the neural network of the discriminator
        netD.zero_grad()

        # Training the discriminator with a real image of the dataset
        real, _ = data
        input = Variable(real)
        target = Variable(torch.ones(input.size()[0]))
        output = netD(input)
        errD_real = criterion(output, target)

        # Training the discriminator with a fake image generated by the generator
        noise = Variable(torch.randn(input.size()[0], 100, 1, 1))
        fake = netG(noise)
        target = Variable(torch.zeros(input.size()[0]))
        output = netD(fake.detach())
        errD_fake = criterion(output, target)

        # Backpropagating the total error
        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        # 2nd Step: Updating the weights of the neural network of the generator
        netG.zero_grad()
        target = Variable(torch.ones(input.size()[0]))
        output = netD(fake)
        errG = criterion(output, target)
        errG.backward()
        optimizerG.step()

        # 3rd Step: Printing the losses and saving the real images and the generated images of the minibatch every 100 steps
        print('[%d/%d][%d/%d] Loss_D: %.4f; Loss_G: %.4f' % (epoch, EPOCH, i, len(train_loader), errD.item(), errG.item()))
        if i % 100 == 0:
            vutils.save_image(real, '%s/real_samples.png' % "./results", normalize=True)
            fake = netG(noise)
            vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize=True)
42.120.10. Best public training#
42.120.10.1. Parameters#
batch_size = 32
LR_G = 0.001
LR_D = 0.0005
beta1 = 0.5
epochs = 100
real_label = 0.9   # one-sided label smoothing: real images get a target of 0.9 instead of 1.0
fake_label = 0
nz = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42.120.10.2. Initialize models and optimizers#
netG = Generator(nz).to(device)
netD = Discriminator().to(device)

criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=LR_D, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR_G, betas=(beta1, 0.999))

fixed_noise = torch.randn(25, nz, 1, 1, device=device)

G_losses = []
D_losses = []
epoch_time = []

def plot_loss(G_losses, D_losses, epoch):
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss - EPOCH " + str(epoch))
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
42.120.11. Show generated images#
def show_generated_img(n_images=5):
    sample = []
    for _ in range(n_images):
        noise = torch.randn(1, nz, 1, 1, device=device)
        gen_image = netG(noise).to("cpu").clone().detach().squeeze(0)
        # Map the Tanh output from [-1, 1] back to [0, 1] for display
        gen_image = gen_image.numpy().transpose(1, 2, 0) * 0.5 + 0.5
        sample.append(gen_image)

    figure, axes = plt.subplots(1, len(sample), figsize=(64, 64))
    for index, axis in enumerate(axes):
        axis.axis('off')
        image_array = sample[index]
        axis.imshow(image_array)
    plt.show()
    plt.close()
42.120.12. Training Loop#
for epoch in range(epochs):
    start = time.time()
    for ii, (real_images, train_labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_images = real_images.to(device)
        batch_size = real_images.size(0)
        labels = torch.full((batch_size, 1), real_label, device=device)

        output = netD(real_images)
        errD_real = criterion(output, labels)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        labels.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, labels)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        labels.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, labels)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        if (ii + 1) % (len(train_loader) // 2) == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch + 1, epochs, ii + 1, len(train_loader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    plot_loss(G_losses, D_losses, epoch)
    G_losses = []
    D_losses = []
    if epoch % 10 == 0:
        show_generated_img()
    epoch_time.append(time.time() - start)

print(">> average EPOCH duration = ", np.mean(epoch_time))
42.120.13. Generation example#
show_generated_img(7)

if not os.path.exists('../output_images'):
    os.mkdir('../output_images')

im_batch_size = 50
n_images = 10000

for i_batch in tqdm(range(0, n_images, im_batch_size)):
    gen_z = torch.randn(im_batch_size, nz, 1, 1, device=device)
    with torch.no_grad():
        gen_images = netG(gen_z)
    images = gen_images.to("cpu").clone().detach()
    images = images.numpy().transpose(0, 2, 3, 1)
    for i_image in range(gen_images.size(0)):
        # normalize=True rescales the Tanh output from [-1, 1] to [0, 1] before saving
        save_image(gen_images[i_image, :, :, :],
                   os.path.join('../output_images', f'image_{i_batch+i_image:05d}.png'),
                   normalize=True)

fig = plt.figure(figsize=(25, 16))
# display 32 images from the last generated batch
for i, j in enumerate(images[:32]):
    ax = fig.add_subplot(4, 8, i + 1, xticks=[], yticks=[])
    plt.imshow(j * 0.5 + 0.5)
42.120.13.1. Save models#
torch.save(netG.state_dict(), 'generator.pth')
torch.save(netD.state_dict(), 'discriminator.pth')
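To reuse the trained weights later without retraining, they can be loaded back into freshly constructed models. A minimal sketch, assuming the same Generator/Discriminator definitions and the file names used above:

# Hedged sketch: reload the saved weights into new model instances.
netG_restored = Generator(nz).to(device)
netG_restored.load_state_dict(torch.load('generator.pth', map_location=device))
netG_restored.eval()  # disable batchnorm updates for inference

netD_restored = Discriminator().to(device)
netD_restored.load_state_dict(torch.load('discriminator.pth', map_location=device))
netD_restored.eval()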
42.120.14. Acknowledgement#
Thanks to jesucristo for creating GAN Introduction. It inspired the majority of the content in this article.