42.122. Summary#

This notebook shows how to train Denoising Diffusion Models.

The code has been adapted and curated from this tutorial by Andras Beres.

42.122.1. Hyperparams#

import numpy as np
import math
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import os

# --- data ---
diffusion_steps = 20  # number of sampling steps passed to model.generate()
image_size = 32  # images are center-cropped and resized to image_size x image_size

# --- sampling: range of signal rates covered by the diffusion schedule ---
min_signal_rate, max_signal_rate = 0.02, 0.95

# --- optimization ---
batch_size = 64
num_epochs = 10
learning_rate = 0.001
weight_decay = 0.0001
ema = 0.999  # decay factor of the exponential moving average of the network weights

42.122.2. Dataset#

def preprocess_image(data):
    """Center-crop a dataset example to a square, resize it and scale to [0, 1].

    Args:
        data: a dataset element (dict) whose ``"image"`` entry is an image tensor
            laid out as (height, width, channels). Pixel values are presumably
            in [0, 255] (they are divided by 255) -- TODO confirm for the dataset.

    Returns:
        A float image of size ``image_size`` x ``image_size``, divided by 255 and
        clipped to [0, 1].
    """
    raw = data["image"]
    dims = tf.shape(raw)
    h, w = dims[0], dims[1]
    side = tf.minimum(h, w)
    # offsets that place the largest possible square in the middle of the image
    top = (h - side) // 2
    left = (w - side) // 2
    square = tf.image.crop_to_bounding_box(raw, top, left, side, side)
    # antialiasing avoids aliasing artifacts whenever the resize downsamples
    resized = tf.image.resize(square, size=[image_size, image_size], antialias=True)
    scaled = resized / 255.0
    return tf.clip_by_value(scaled, 0.0, 1.0)

def prepare_dataset(split):
    """Build the batched MNIST input pipeline for a TFDS ``split``.

    Every element is a batch of exactly ``batch_size`` images produced by
    ``preprocess_image``; an incomplete final batch is dropped.
    """
    ds = tfds.load('mnist', split=split, shuffle_files=True)
    ds = ds.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    # keep the preprocessed images so crop/resize is not recomputed every epoch
    ds = ds.cache()
    # a single pass over the data (repeat(1) does not duplicate anything)
    ds = ds.repeat(1)
    # shuffling after cache() gives a fresh order on each iteration
    ds = ds.shuffle(10000)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)

# build the training pipeline first, then the evaluation pipeline
train_dataset, val_dataset = prepare_dataset("train"), prepare_dataset("test")

42.122.3. Denoising Network#

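Two denoising networks are defined below: a Residual U-Net and a simpler custom convolutional network. The training run in this notebook uses the custom network; the U-Net can be selected instead in the "Complete Model" section.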

42.122.4. TODO: can we use something simpler?#

embedding_max_frequency = 1000.0
embedding_dims = 64


def sinusoidal_embedding(x):
    """Embed a scalar input (here: the noise variance) with sinusoidal features.

    ``embedding_dims // 2`` frequencies are spaced geometrically between 1 and
    ``embedding_max_frequency``. The result concatenates the sines and cosines
    of ``2 * pi * frequency * x`` along the last axis.

    Args:
        x: tensor whose last axis has size 1 (the networks in this notebook pass
            a ``(batch, 1, 1, 1)`` noise-variance tensor).

    Returns:
        Tensor with the leading dimensions of ``x`` and a last axis of size
        ``embedding_dims``.
    """
    embedding_min_frequency = 1.0
    frequencies = tf.exp(
        tf.linspace(
            tf.math.log(embedding_min_frequency),
            tf.math.log(embedding_max_frequency),
            embedding_dims // 2,
        )
    )
    angular_speeds = 2.0 * math.pi * frequencies
    # concatenate along the last axis rather than a hard-coded axis=3, so inputs
    # of any rank work (identical result for the rank-4 inputs used here)
    embeddings = tf.concat(
        [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=-1
    )
    return embeddings

42.122.5. Custom Residual Network#

def get_network_custom(image_size, block_depth=17, output_channels=1):
    """Build a plain convolutional noise-prediction network.

    After embedding the noise variance and concatenating it with the image
    features, the network applies one 3x3 ReLU convolution. It then stacks
    ``block_depth - 1`` (batch-norm, 3x3 swish convolution) pairs, whose
    convolutions are named ``conv2`` .. ``conv{block_depth}``. Despite the model
    name, no skip/residual connections are added.

    Args:
        image_size: height and width of the square input images.
        block_depth: index of the last named convolution (see above).
        output_channels: number of image channels, used for both the input
            image and the predicted noise.

    Returns:
        A ``tf.keras.Model`` mapping ``[noisy_images, noise_variances]`` to a
        tensor of shape ``(image_size, image_size, output_channels)``.
    """
    image_input = tf.keras.Input(shape=(image_size, image_size, output_channels))
    variance_input = tf.keras.Input(shape=(1, 1, 1))

    # embed the noise variance and broadcast it over every pixel
    emb = tf.keras.layers.Lambda(sinusoidal_embedding)(variance_input)
    emb = tf.keras.layers.UpSampling2D(size=image_size, interpolation="nearest")(emb)
    h = tf.keras.layers.Conv2D(32, kernel_size=1)(image_input)
    h = tf.keras.layers.Concatenate()([h, emb])

    h = tf.keras.layers.Conv2D(64, 3, padding='same', activation=tf.nn.relu)(h)
    for idx in range(2, block_depth + 1):
        h = tf.keras.layers.BatchNormalization(center=False, scale=False)(h)
        h = tf.keras.layers.Conv2D(
            64,
            3,
            padding='same',
            name=f'conv{idx}',
            activation=tf.keras.activations.swish,
            use_bias=False,
        )(h)

    # zero-initialized kernel: a freshly built network predicts (near-)zero noise
    head = tf.keras.layers.Conv2D(output_channels, kernel_size=1, kernel_initializer="zeros")
    return tf.keras.Model([image_input, variance_input], head(h), name="simple-residual-net")

42.122.6. Residual U-Net#

# Residual U-Net configuration: channel width per resolution level
# ([32, 64, 96, 128]) and number of residual blocks per stage.
widths = [32 * level for level in range(1, 5)]
block_depth = 2

def ResidualBlock(width):
    """Return a callable that applies one pre-normalized residual block.

    The callable computes ``conv3x3(conv3x3_swish(batchnorm(x))) + shortcut(x)``.
    The shortcut is ``x`` itself when its channel axis (index 3) already has
    ``width`` channels, and a 1x1 convolution otherwise.

    Args:
        width: number of output channels.
    """
    def apply(x):
        needs_projection = x.shape[3] != width
        shortcut = tf.keras.layers.Conv2D(width, kernel_size=1)(x) if needs_projection else x
        out = tf.keras.layers.BatchNormalization(center=False, scale=False)(x)
        out = tf.keras.layers.Conv2D(
            width, kernel_size=3, padding="same", activation=tf.keras.activations.swish
        )(out)
        out = tf.keras.layers.Conv2D(width, kernel_size=3, padding="same")(out)
        return tf.keras.layers.Add()([out, shortcut])

    return apply


def DownBlock(width, block_depth):
    """Return a callable implementing one encoder stage of the U-Net.

    The callable takes ``[x, skips]`` and applies ``block_depth`` residual blocks
    of ``width`` channels. It appends each intermediate output to ``skips``,
    mutating the caller's list so a matching ``UpBlock`` can pop them. It
    returns the result downsampled 2x by average pooling.
    """
    def apply(inputs):
        features, skip_stack = inputs
        for _ in range(block_depth):
            features = ResidualBlock(width)(features)
            skip_stack.append(features)
        pool = tf.keras.layers.AveragePooling2D(pool_size=2)
        return pool(features)

    return apply


def UpBlock(width, block_depth):
    """Return a callable implementing one decoder stage of the U-Net.

    The callable takes ``[x, skips]`` and first upsamples ``x`` 2x (bilinear).
    It then repeats ``block_depth`` times: pop the most recent entry from
    ``skips``, concatenate it with the current features, and apply a residual
    block of ``width`` channels.
    """
    def apply(inputs):
        features, skip_stack = inputs
        features = tf.keras.layers.UpSampling2D(size=2, interpolation="bilinear")(features)
        for _ in range(block_depth):
            merge = tf.keras.layers.Concatenate()
            features = merge([features, skip_stack.pop()])
            features = ResidualBlock(width)(features)
        return features

    return apply


def get_network(image_size, widths, block_depth, output_channels=1):
    """Build a residual U-Net that predicts the noise component of a noisy image.

    Args:
        image_size: height and width of the square input images. It must be
            divisible by ``2 ** (len(widths) - 1)``, because every down stage
            halves the resolution and every up stage doubles it again.
        widths: channel widths per resolution level. All but the last define
            the down/up stages; the last one is used in the bottleneck.
        block_depth: number of residual blocks per stage and in the bottleneck.
        output_channels: number of image channels, used for both the input
            image and the predicted noise (default 1).

    Returns:
        A ``tf.keras.Model`` named ``"residual_unet"`` mapping
        ``[noisy_images, noise_variances]`` to a noise prediction.

    Raises:
        ValueError: if ``image_size`` is not divisible by the total
            downsampling factor (the skip concatenations would not line up).
    """
    downsampling = 2 ** (len(widths) - 1)
    if image_size % downsampling != 0:
        raise ValueError(
            f"image_size={image_size} must be divisible by {downsampling} "
            f"for {len(widths) - 1} down/up stages"
        )

    noisy_images = tf.keras.Input(shape=(image_size, image_size, output_channels))
    noise_variances = tf.keras.Input(shape=(1, 1, 1))

    # embed the noise variance and broadcast it over every pixel
    e = tf.keras.layers.Lambda(sinusoidal_embedding)(noise_variances)
    e = tf.keras.layers.UpSampling2D(size=image_size, interpolation="nearest")(e)

    x = tf.keras.layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
    x = tf.keras.layers.Concatenate()([x, e])

    # DownBlock appends to `skips`; UpBlock pops from it in reverse order
    skips = []
    for width in widths[:-1]:
        x = DownBlock(width, block_depth)([x, skips])

    for _ in range(block_depth):
        x = ResidualBlock(widths[-1])(x)

    for width in reversed(widths[:-1]):
        x = UpBlock(width, block_depth)([x, skips])

    # zero-initialized kernel: a freshly built network predicts (near-)zero noise
    x = tf.keras.layers.Conv2D(output_channels, kernel_size=1, kernel_initializer="zeros")(x)

    return tf.keras.Model([noisy_images, noise_variances], x, name="residual_unet")

42.122.7. Diffusion Model#

class DiffusionModel(tf.keras.Model):
    """Denoising diffusion model wrapping a noise-prediction network.

    ``network`` is trained to predict the Gaussian noise that was mixed into
    normalized images. ``ema_network`` tracks an exponential moving average of
    its weights and is the network used for sampling.

    Args:
        network: a Keras model mapping ``[noisy_images, noise_variances]`` to a
            noise prediction with the same shape as ``noisy_images``.
    """

    def __init__(self, network):
        super().__init__()
        # adapted on the training data before fit(); normalizes inputs in
        # train_step and is inverted in denormalize()
        self.normalizer = tf.keras.layers.Normalization()
        self.network = network
        # same architecture; its weights are pulled towards `network`'s by the
        # exponential moving average in train_step
        self.ema_network = tf.keras.models.clone_model(self.network)

    def compile(self, **kwargs):
        """Compile the model and create the loss trackers reported during fit()."""
        super().compile(**kwargs)
        self.noise_loss_tracker = tf.keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = tf.keras.metrics.Mean(name="i_loss")

    @property
    def metrics(self):
        # listing the trackers here lets Keras reset them between epochs
        return [self.noise_loss_tracker, self.image_loss_tracker]

    def denormalize(self, images):
        """Undo the normalizer's standardization and clip to the pixel range [0, 1]."""
        images = self.normalizer.mean + images * self.normalizer.variance**0.5
        return tf.clip_by_value(images, 0.0, 1.0)

    def diffusion_schedule(self, diffusion_times):
        """Map diffusion times in [0, 1] to ``(noise_rates, signal_rates)``.

        Time 0 gives a signal rate of ``max_signal_rate``; time 1 gives
        ``min_signal_rate``. The angle is interpolated linearly in between.
        """
        # diffusion times -> angles
        start_angle = tf.acos(max_signal_rate)
        end_angle = tf.acos(min_signal_rate)
        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
        # angles -> signal and noise rates
        signal_rates = tf.cos(diffusion_angles)
        noise_rates = tf.sin(diffusion_angles)
        # note that their squared sum is always: sin^2(x) + cos^2(x) = 1
        return noise_rates, signal_rates

    def denoise(self, noisy_images, noise_rates, signal_rates, training):
        """Predict the noise in ``noisy_images`` and derive the clean-image estimate.

        Returns:
            ``(pred_noises, pred_images)``.
        """
        # the exponential moving average weights are used at evaluation
        network = self.network if training else self.ema_network
        # predict noise component and calculate the image component using it
        pred_noises = network([noisy_images, noise_rates**2], training=training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, steps):
        """Deterministically denoise ``initial_noise`` in ``steps`` steps.

        Args:
            initial_noise: tensor of Gaussian noise shaped like a batch of images.
            steps: number of sampling steps; the diffusion time goes from 1
                towards 0 in increments of ``1 / steps``.

        Returns:
            The normalized images predicted at the final step.

        Raises:
            ValueError: if ``steps`` is smaller than 1.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        batch = initial_noise.shape[0]
        step_size = 1.0 / steps

        # important line:
        # at the first sampling step, the "noisy image" is pure noise
        # but its signal rate is assumed to be nonzero (min_signal_rate)
        next_noisy_images = initial_noise
        # iterate exactly `steps` times so the loop agrees with step_size
        for step in range(steps):
            noisy_images = next_noisy_images
            diffusion_times = tf.ones((batch, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=False
            )

            # remix the predicted components at the next (less noisy) time;
            # this new noisy image will be used in the next step
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
                next_diffusion_times
            )
            next_noisy_images = (
                next_signal_rates * pred_images + next_noise_rates * pred_noises
            )
        return pred_images

    def generate(self, num_images, steps):
        """Sample ``num_images`` images using ``steps`` reverse-diffusion steps.

        Returns:
            Denormalized images clipped to [0, 1].
        """
        # (height, width, channels) taken from the network's image input, so
        # networks built with another size or channel count also work
        image_shape = tuple(self.network.inputs[0].shape[1:])
        initial_noise = tf.random.normal(shape=(num_images,) + image_shape)
        generated_images = self.reverse_diffusion(initial_noise, steps)
        generated_images = self.denormalize(generated_images)
        return generated_images

    def train_step(self, images):
        """Run one optimization step on a batch of images and update the EMA weights."""
        # normalize images to have standard deviation of 1, like the noises
        images = self.normalizer(images, training=True)
        # use the runtime shape so any batch size works (not only the global batch_size)
        image_shape = tf.shape(images)
        noises = tf.random.normal(shape=image_shape)
        diffusion_times = tf.random.uniform(
            shape=(image_shape[0], 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)

        # mix the images with noises accordingly
        noisy_images = signal_rates * images + noise_rates * noises

        with tf.GradientTape() as tape:
            # train the network to separate noisy images to their components
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=True
            )

            # NOTE(review): with loss=tf.keras.losses.mean_absolute_error these are
            # per-pixel tensors (mean over the last axis only); tape.gradient
            # then differentiates their sum
            noise_loss = self.loss(noises, pred_noises)  # used for training
            image_loss = self.loss(images, pred_images)  # only used as metric

        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)

        # track the exponential moving averages of weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        return {m.name: m.result() for m in self.metrics}

42.122.8. Complete Model#

Choose one of the two denoising networks defined above.

network = get_network_custom(image_size, block_depth=10)
# alternative: the residual U-Net
# network = get_network(image_size, widths, block_depth)
# summary() prints the table itself and returns None; wrapping it in print()
# would emit an extra "None" line
network.summary()
Model: "simple-residual-net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_2 (InputLayer)           [(None, 1, 1, 1)]    0           []                               
                                                                                                  
 input_1 (InputLayer)           [(None, 32, 32, 1)]  0           []                               
                                                                                                  
 lambda (Lambda)                (None, 1, 1, 64)     0           ['input_2[0][0]']                
                                                                                                  
 conv2d (Conv2D)                (None, 32, 32, 32)   64          ['input_1[0][0]']                
                                                                                                  
 up_sampling2d (UpSampling2D)   (None, 32, 32, 64)   0           ['lambda[0][0]']                 
                                                                                                  
 concatenate (Concatenate)      (None, 32, 32, 96)   0           ['conv2d[0][0]',                 
                                                                  'up_sampling2d[0][0]']          
                                                                                                  
 conv2d_1 (Conv2D)              (None, 32, 32, 64)   55360       ['concatenate[0][0]']            
                                                                                                  
 batch_normalization (BatchNorm  (None, 32, 32, 64)  128         ['conv2d_1[0][0]']               
 alization)                                                                                       
                                                                                                  
 conv2 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization[0][0]']    
                                                                                                  
 batch_normalization_1 (BatchNo  (None, 32, 32, 64)  128         ['conv2[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv3 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_1[0][0]']  
                                                                                                  
 batch_normalization_2 (BatchNo  (None, 32, 32, 64)  128         ['conv3[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv4 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_2[0][0]']  
                                                                                                  
 batch_normalization_3 (BatchNo  (None, 32, 32, 64)  128         ['conv4[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv5 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_3[0][0]']  
                                                                                                  
 batch_normalization_4 (BatchNo  (None, 32, 32, 64)  128         ['conv5[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv6 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_4[0][0]']  
                                                                                                  
 batch_normalization_5 (BatchNo  (None, 32, 32, 64)  128         ['conv6[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv7 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_5[0][0]']  
                                                                                                  
 batch_normalization_6 (BatchNo  (None, 32, 32, 64)  128         ['conv7[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv8 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_6[0][0]']  
                                                                                                  
 batch_normalization_7 (BatchNo  (None, 32, 32, 64)  128         ['conv8[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv9 (Conv2D)                 (None, 32, 32, 64)   36864       ['batch_normalization_7[0][0]']  
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 32, 32, 64)  128         ['conv9[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 conv10 (Conv2D)                (None, 32, 32, 64)   36864       ['batch_normalization_8[0][0]']  
                                                                                                  
 conv2d_2 (Conv2D)              (None, 32, 32, 1)    65          ['conv10[0][0]']                 
                                                                                                  
==================================================================================================
Total params: 388,417
Trainable params: 387,265
Non-trainable params: 1,152
__________________________________________________________________________________________________
None
model = DiffusionModel(network)
optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
model.compile(optimizer=optimizer, loss=tf.keras.losses.mean_absolute_error)
# learn the data mean/variance before training: train_step normalizes with them
# and generate() undoes the normalization with them
model.normalizer.adapt(train_dataset)
model.fit(train_dataset, epochs=num_epochs)
Epoch 1/10
937/937 [==============================] - 2197s 2s/step - n_loss: 0.1291 - i_loss: 0.3230
Epoch 2/10
937/937 [==============================] - 1940s 2s/step - n_loss: 0.0934 - i_loss: 0.2160
Epoch 3/10
937/937 [==============================] - 1929s 2s/step - n_loss: 0.0881 - i_loss: 0.2033
Epoch 4/10
937/937 [==============================] - 1960s 2s/step - n_loss: 0.0864 - i_loss: 0.1974
Epoch 5/10
937/937 [==============================] - 1923s 2s/step - n_loss: 0.0841 - i_loss: 0.1926
Epoch 6/10
937/937 [==============================] - 1923s 2s/step - n_loss: 0.0835 - i_loss: 0.1908
Epoch 7/10
937/937 [==============================] - 1926s 2s/step - n_loss: 0.0823 - i_loss: 0.1864
Epoch 8/10
937/937 [==============================] - 1948s 2s/step - n_loss: 0.0813 - i_loss: 0.1817
Epoch 9/10
937/937 [==============================] - 1923s 2s/step - n_loss: 0.0813 - i_loss: 0.1825
Epoch 10/10
937/937 [==============================] - 1927s 2s/step - n_loss: 0.0801 - i_loss: 0.1787
<keras.callbacks.History at 0x25c6dfaf370>

42.122.9. Visualize#

num_rows = 2
num_cols = 3

generated_images = model.generate(
    num_images=num_rows * num_cols,
    steps=diffusion_steps,
)

plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
for index, image in enumerate(generated_images):
    plt.subplot(num_rows, num_cols, index + 1)
    # drop a trailing single-channel axis and render grayscale; without cmap
    # matplotlib would color single-channel data with its default colormap
    # (cmap is ignored for RGB images)
    plt.imshow(np.squeeze(np.asarray(image)), cmap="gray")
    plt.axis("off")

plt.tight_layout()
plt.show()
../../../_images/denoising-difussion-model_19_0.png

42.122.10. Acknowledgments#

Thanks to Maciej Skorski for creating Denoising Diffusion Model. It inspired the majority of the content in this chapter.