Semi-Supervised GAN

Tutorial: Implementing Semi-Supervised GAN

Architecture Diagram

Below is a high-level diagram of the SGAN model we will implement in this tutorial. It is a bit more complex than the general, conceptual diagram we introduced at the beginning of this chapter.

To solve the multi-class classification problem of distinguishing among the real label classes, the Discriminator uses the softmax function, which gives a probability distribution over a specified number of classes — in our case, 10. The higher the probability assigned to a given label, the more confident the Discriminator is that the example belongs to that class. To compute the classification error, we use the cross-entropy loss, which measures the difference between the output probabilities and the target, one-hot-encoded labels.

To output the real-vs-fake probability, the Discriminator uses the sigmoid activation function and trains its parameters by backpropagating the binary cross-entropy loss.
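
To make these two loss terms concrete, here is a small NumPy sketch with arbitrary toy values; it is purely illustrative and not part of the SGAN code we build below:

import numpy as np

# Hypothetical logits for the 10 real classes (arbitrary toy values)
logits = np.array([2.0, 0.1, -1.3, 0.4, 0.0, -0.7, 1.1, -0.2, 0.3, -1.0])
probs = np.exp(logits) / np.exp(logits).sum()       # softmax: distribution over the 10 classes
target = np.eye(10)[3]                              # one-hot-encoded label for class "3"
supervised_loss = -np.sum(target * np.log(probs))   # cross-entropy vs. the true class

# Real-vs-fake head: sigmoid on a single logit, binary cross-entropy against label 1 ("real")
d_out = 1.0 / (1.0 + np.exp(-0.8))                  # sigmoid
unsupervised_loss = -np.log(d_out)                  # binary cross-entropy for a real example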

Implementation

The Generator

The Generator network is the same as the one we implemented for the DCGAN in Chapter 4. Using transposed convolution layers, the Generator transforms the input random noise vector into a 28 x 28 x 1 image.

def build_generator(img_shape, z_dim):

    model = Sequential()

    # Reshape input into 7x7x256 tensor via a fully connected layer
    model.add(Dense(256 * 7 * 7, input_dim=z_dim))
    model.add(Reshape((7, 7, 256)))

    # Transposed convolution layer, from 7x7x256 into 14x14x128 tensor
    model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'))

    # Batch normalization
    model.add(BatchNormalization())

    # Leaky ReLU
    model.add(LeakyReLU(alpha=0.01))

    # Transposed convolution layer, from 14x14x128 into 28x28x1 tensor
    model.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'))

    # Tanh activation
    model.add(Activation('tanh'))

    z = Input(shape=(z_dim,))
    img = model(z)

    return Model(z, img)

The Discriminator

The Discriminator is the most complex part of the model. Recall that the SGAN Discriminator has a dual objective:

  1. Distinguish real examples from fake ones. For this, the SGAN Discriminator uses the sigmoid function, outputting a single probability for binary classification.
  2. For the real examples, accurately classify their label. For this, the SGAN Discriminator uses the softmax function, outputting a vector of probabilities, one for each of the target classes.
The core Discriminator network

We start by defining the core Discriminator network. As you may notice, the model below is very similar to the ConvNet-based Discriminator we built for the DCGAN in Chapter 4; in fact, it is exactly the same up to and including the 3 x 3 x 128 convolutional layer, its batch normalization, and its Leaky ReLU activation.

After that layer, we add dropout. Dropout is a regularization technique that helps prevent overfitting by randomly dropping neurons and their connections from the neural network during training. This forces the remaining neurons to reduce their interdependence and develop a more general representation of the underlying data. The fraction of neurons to be randomly dropped is specified by the rate parameter, which is set to 0.5 in our implementation: model.add(Dropout(0.5)). We add dropout because of the increased complexity of the SGAN classification task and to improve the model's ability to generalize from only 100 labeled examples.
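
The core network is sketched below. It mirrors the DCGAN Discriminator from Chapter 4 up to and including the 3 x 3 x 128 convolutional layer with its batch normalization and Leaky ReLU, followed by the dropout layer; the final Dense layer outputs one unactivated value (logit) per class so that the supervised and unsupervised heads defined next can both be attached. Treat the exact filter counts and kernel sizes as assumptions reconstructed from this description rather than the definitive listing:

from keras.models import Sequential
from keras.layers import (BatchNormalization, Conv2D, Dense, Dropout,
                          Flatten, LeakyReLU)

num_classes = 10  # the ten MNIST digit classes

def build_discriminator_net(img_shape):

    model = Sequential()

    # Convolutional layer, from 28x28x1 into 14x14x32 tensor
    model.add(Conv2D(32, kernel_size=3, strides=2,
                     input_shape=img_shape, padding='same'))
    model.add(LeakyReLU(alpha=0.01))

    # Convolutional layer, from 14x14x32 into 7x7x64 tensor
    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # Convolutional layer, from 7x7x64 into 3x3x128 tensor (valid padding)
    model.add(Conv2D(128, kernel_size=3, strides=2))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # Dropout regularization, as described above
    model.add(Dropout(0.5))

    # Flatten and output one raw logit per real class
    model.add(Flatten())
    model.add(Dense(num_classes))

    return model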

The supervised discriminator

In the code block below, we take the core Discriminator network implemented above and use it to build the supervised portion of the Discriminator model:

def build_discriminator_supervised(discriminator_net):

    model = Sequential()

    model.add(discriminator_net)

    # Softmax giving probability distribution over the real classes
    model.add(Activation('softmax'))

    return model
The unsupervised discriminator

The code below implements the unsupervised portion of the Discriminator model on top of the core Discriminator network. Notice the predict(x) function, in which we transform the output of the 10 class neurons into a binary, real-vs-fake probability.

def build_discriminator_unsupervised(discriminator_net):

    model = Sequential()

    model.add(discriminator_net)

    def predict(x):
        # Transform distribution over real classes into a binary real-vs-fake probability
        prediction = 1.0 - (1.0 / (K.sum(K.exp(x), axis=-1, keepdims=True) + 1.0))
        return prediction

    # Add the 'real-vs-fake' output neuron defined above
    model.add(Lambda(predict))

    return model
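
The Lambda above implements D(x) = Z(x) / (Z(x) + 1), where Z(x) is the sum of the exponentiated class outputs; the larger the class activations, the closer the real-vs-fake probability gets to 1. A quick NumPy sanity check on arbitrary toy values (purely illustrative):

import numpy as np

logits = np.array([2.0, -1.0, 0.5, 0.0, 1.5, -0.5, 0.3, -2.0, 0.1, 0.7])
Z = np.exp(logits).sum()

print(1.0 - 1.0 / (Z + 1.0))  # the Lambda's expression
print(Z / (Z + 1.0))          # identical value: Z(x) / (Z(x) + 1)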

Build the Model

Next, we build and compile the Discriminator and Generator models. Notice the use of categorical_crossentropy and binary_crossentropy loss functions for the supervised loss and the unsupervised loss, respectively.
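
The build code below also assumes the MNIST image dimensions and the noise-vector length defined earlier in the chapter. For reference, typical values would be (z_dim = 100 is an assumption; the image shape follows from the Generator's output):

img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)  # 28 x 28 x 1 grayscale MNIST images
z_dim = 100                                 # length of the random noise vector (assumed value)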

# -------------------------
# Build the Discriminator
# -------------------------

discriminator_net = build_discriminator_net(img_shape)

# Compile the Discriminator for supervised training
discriminator_supervised = build_discriminator_supervised(discriminator_net)
discriminator_supervised.compile(
    loss='categorical_crossentropy', metrics=['accuracy'], optimizer=Adam())

# Compile the Discriminator for unsupervised training
discriminator_unsupervised = build_discriminator_unsupervised(discriminator_net)
discriminator_unsupervised.compile(loss='binary_crossentropy', optimizer=Adam())

# ---------------------
#  Build the Generator
# ---------------------

generator = build_generator(img_shape, z_dim)

# Keep the Discriminator's parameters fixed while training the Generator
discriminator_unsupervised.trainable = False

# Combined GAN model: the Generator stacked with the (frozen) unsupervised Discriminator
def build_gan(generator, discriminator):
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    return model

combined = build_gan(generator, discriminator_unsupervised)
combined.compile(loss='binary_crossentropy', optimizer=Adam())

Training

SGAN Training algorithm

for each training iteration do:

  1. Train the Discriminator (Supervised):
    1. Take a random mini-batch of labeled real examples (x, y)
    2. Compute D((x, y)) for the given mini-batch and backpropagate the multiclass classification loss to update θ^(D) to minimize the loss.
  2. Train the Discriminator (Unsupervised):
    1. Take a random mini-batch of unlabeled real examples x
    2. Compute D(x) for the given mini-batch and backpropagate the binary classification loss to update θ^(D) to minimize the loss.
    3. Take a mini-batch of random noise vectors z and generate a mini-batch of fake examples: G(z) = x*
    4. Compute D(x*) for the given mini-batch and backpropagate the binary classification loss to update θ^(D) to minimize the loss.
  3. Train the Generator:
    1. Take a mini-batch of random noise vectors z and generate a mini-batch of fake examples: G(z) = x*
    2. Compute D(x*) for the given mini-batch and backpropagate the binary classification loss to update θ^(G) to maximize the loss.

end for

d_accuracies = []
d_losses = []

def train(iterations, batch_size, sample_interval):

    # Labels for real and fake examples
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for iteration in range(iterations):

        # -----------------------
        # Train the Discriminator
        # -----------------------

        # Labeled examples
        imgs, labels = dataset.batch_labeled(batch_size)

        # One-hot encode labels
        labels = to_categorical(labels, num_classes=num_classes)

        # Unlabeled examples
        imgs_unlabeled = dataset.batch_unlabeled(batch_size)

        # Generate a batch of fake images
        z = np.random.normal(0, 1, (batch_size, z_dim))
        gen_imgs = generator.predict(z)

        # Train on real labeled examples
        d_loss_supervised, accuracy = discriminator_supervised.train_on_batch(imgs, labels)

        # Train on real unlabeled examples
        d_loss_real = discriminator_unsupervised.train_on_batch(imgs_unlabeled, real)

        # Train on fake examples
        d_loss_fake = discriminator_unsupervised.train_on_batch(gen_imgs, fake)

        d_loss_unsupervised = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --------------------
        # Train the Generator
        # --------------------

        # Generate a batch of fake images
        z = np.random.normal(0, 1, (batch_size, z_dim))
        gen_imgs = generator.predict(z)

        # Train Generator
        g_loss = combined.train_on_batch(z, np.ones((batch_size, 1)))

        # Save Discriminator supervised classification loss and accuracy
        d_losses.append(d_loss_supervised)
        d_accuracies.append(accuracy)

        if iteration % sample_interval == 0:
            # Output training progress
            print("%d [D loss supervised: %.4f, acc.: %.2f%%] [D loss unsupervised: %.4f] [G loss: %f]" %
                   (iteration, d_loss_supervised, 100 * accuracy, d_loss_unsupervised, g_loss))
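
The loop can then be run with hyperparameters along the following lines; the exact values are assumptions, not the definitive settings:

# Hypothetical hyperparameters -- adjust as needed
iterations = 8000
batch_size = 32
sample_interval = 800

train(iterations, batch_size, sample_interval)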

Comparison to a Fully-Supervised Classifier

To make the comparison as fair as possible, we use the same network architecture for the fully-supervised classifier as the one used for the supervised Discriminator training. The idea is that this allows us to isolate the improvement in the classifier's ability to generalize that is attributable to the GAN-enabled semi-supervised learning.

# Fully supervised classifier with the same network architecture as the SGAN Discriminator
mnist_classifier = build_discriminator_supervised(build_discriminator_net(img_shape))
mnist_classifier.compile(
       loss='categorical_crossentropy', metrics=['accuracy'], optimizer=Adam())
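
Below is a sketch of how the comparison might then be run. The helpers dataset.training_set() and dataset.test_set() are assumptions standing in for whatever data-loading code the chapter's Dataset class provides:

# Hypothetical data-loading helpers -- adapt to your own Dataset class
imgs, labels = dataset.training_set()                    # the 100 labeled training examples
labels = to_categorical(labels, num_classes=num_classes)

# Train on the labeled examples only
mnist_classifier.fit(x=imgs, y=labels, batch_size=32, epochs=30, verbose=1)

# Evaluate on the held-out test set
x_test, y_test = dataset.test_set()
y_test = to_categorical(y_test, num_classes=num_classes)
_, accuracy = mnist_classifier.evaluate(x=x_test, y=y_test)
print("Test accuracy: %.2f%%" % (100 * accuracy))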

With a lot more training data, the fully-supervised classifier’s ability to generalize improves dramatically. Using the same setup and training the fully-supervised classifier with 10,000 labeled examples (100 times as many as we originally used), we achieve an accuracy of 98.16%. But that would no longer be a semi-supervised setting.
