Semi-Supervised GAN

Why Semi-Supervised Learning?

Semi-supervised learning is one of the most promising areas of practical application of GANs. Unlike supervised learning, where we need a label for every example in our dataset, and unsupervised learning, where no labels are used at all, semi-supervised learning has labels for only a small subset of the examples.

The lack of labeled datasets is one of the main bottlenecks in machine learning research and practical applications. While unlabeled data is abundant (the Internet is an essentially limitless source of unlabeled images, videos, and text), assigning class labels to it is often prohibitively expensive, impractical, and time-consuming.

By serving as a source of additional information that can be used for training, generative models have proved useful in improving the accuracy of semi-supervised models.

What is Semi-Supervised GAN?

The Semi-Supervised GAN (SGAN) is a generative adversarial network whose Discriminator is a multiclass classifier. Instead of distinguishing between only two classes (“real” and “fake”), it learns to distinguish between [katex]N + 1[/katex] classes, where [katex]N[/katex] is the number of classes in the training dataset, with one added for the fake examples produced by the Generator.
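To make the [katex]N + 1[/katex] idea concrete, here is a minimal sketch of such a multiclass Discriminator, assuming a Keras-style convolutional classifier on 28 x 28 grayscale images; the layer sizes and the `num_classes` value are illustrative assumptions, not part of the SGAN specification.

```python
from tensorflow.keras import layers, models

num_classes = 10  # N: number of real classes in the dataset (assumed for illustration)

# A toy Discriminator: any feature extractor works, as long as the final
# layer outputs N + 1 probabilities (N real classes plus one "fake" class).
discriminator = models.Sequential([
    layers.Input(shape=(28, 28, 1)),
    layers.Conv2D(32, kernel_size=3, strides=2, padding="same", activation="relu"),
    layers.Conv2D(64, kernel_size=3, strides=2, padding="same", activation="relu"),
    layers.Flatten(),
    layers.Dense(num_classes + 1, activation="softmax"),
])
```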

Turning the Discriminator from a binary into a multiclass classifier may seem like a trivial change, but its implications are more far-reaching than they may appear at first glance.

As the diagram in the figure above indicates, the task of distinguishing between multiple classes does not affect the Discriminator alone; compared to a traditional GAN, it also adds complexity to the SGAN architecture, its training process, and its training objectives.

Architecture

The SGAN Generator’s purpose is the same as in the original GAN: it takes in a vector of random numbers and produces fake examples whose goal is to be indistinguishable from the training dataset; no change here.

The SGAN Discriminator, however, diverges considerably from the original GAN implementation. Instead of two, it receives three kinds of input: fake examples produced by the Generator ([katex]x^*[/katex]), real examples without labels from the training dataset ([katex]x[/katex]), and real examples with labels from the training dataset ([katex](x, y)[/katex]), where [katex]y[/katex] denotes the label for the given example. Instead of binary classification, the SGAN Discriminator’s goal is to categorize the input example into its corresponding class if the example is real, or to reject the example as fake (which can be thought of as a special additional class).
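In practice, a common way to implement this dual role is to share one core network between two heads: a supervised head with a softmax over the [katex]N[/katex] real classes, and an unsupervised head that collapses the same logits into a single real-versus-fake probability (a trick introduced by Salimans et al. in “Improved Techniques for Training GANs”). The sketch below assumes Keras; the helper names are illustrative, not prescribed by the SGAN paper.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

num_classes = 10  # assumed number of real classes

# Shared feature extractor that ends in N raw class logits (no activation).
core = models.Sequential([
    layers.Input(shape=(28, 28, 1)),
    layers.Conv2D(64, kernel_size=3, strides=2, padding="same", activation="relu"),
    layers.Flatten(),
    layers.Dense(num_classes),
])

# Supervised head: classifies labeled real examples into the N classes.
discriminator_supervised = models.Sequential([core, layers.Activation("softmax")])

# Unsupervised head: turns the same logits into a real-vs-fake probability.
# With Z(x) = sum_k exp(logit_k), p(real) = Z / (Z + 1).
def real_probability(logits):
    z = tf.reduce_sum(tf.exp(logits), axis=-1, keepdims=True)
    return z / (z + 1.0)

discriminator_unsupervised = models.Sequential([core, layers.Lambda(real_probability)])
```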

The table below summarizes the key takeaways about the two SGAN subnetworks.

Generator
Input: a vector of random numbers ([katex]z[/katex])
Output: fake examples ([katex]x^*[/katex]) striving to be as convincing as possible
Goal: produce fake examples that are indistinguishable from the training dataset

Discriminator
Input: fake examples ([katex]x^*[/katex]), unlabeled real examples ([katex]x[/katex]), and labeled real examples ([katex](x, y)[/katex])
Output: the likelihood that an input example belongs to one of the [katex]N[/katex] real classes or is fake
Goal: assign labeled real examples to their correct classes and reject examples produced by the Generator

Training process

Recall that in a regular GAN, we train the Discriminator by computing the losses for [katex]D(x)[/katex] and [katex]D(x^*)[/katex] and backpropagating the total loss to update the Discriminator’s trainable parameters so as to minimize it. The Generator is trained by backpropagating the Discriminator’s loss for [katex]D(x^*)[/katex], seeking to maximize it, so that the fake examples it produces are misclassified as real.
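As a refresher, these regular GAN losses can be written in a few lines. The following is only a sketch, with placeholder scores standing in for the actual Discriminator outputs [katex]D(x)[/katex] and [katex]D(x^*)[/katex].

```python
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy()

# Placeholder scores; in a real model these come from D(x) and D(x*).
d_real = tf.constant([[0.9], [0.8]])  # D(x):  Discriminator scores on real examples
d_fake = tf.constant([[0.3], [0.1]])  # D(x*): Discriminator scores on fake examples

# Discriminator loss: push D(x) toward 1 and D(x*) toward 0.
d_loss = bce(tf.ones_like(d_real), d_real) + bce(tf.zeros_like(d_fake), d_fake)

# Generator loss: push D(x*) toward 1, i.e. get fakes misclassified as real.
g_loss = bce(tf.ones_like(d_fake), d_fake)
```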

To train the SGAN, in addition to [katex]D(x)[/katex] and [katex]D(x^*)[/katex], we also have to compute the loss for the supervised training examples: [katex]D(x, y)[/katex]. These losses correspond to the dual learning objective the SGAN Discriminator has to grapple with: distinguishing real examples from fake ones while also learning to classify real examples into their correct classes. Using the terminology from the original paper, these dual objectives correspond to two kinds of losses: the “Supervised Loss” and the “Unsupervised Loss”.
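Concretely, the two losses might be combined as below. This is a sketch with placeholder predictions; in an actual SGAN the class probabilities come from the supervised head and the real/fake scores from the unsupervised head.

```python
import tensorflow as tf

supervised_loss = tf.keras.losses.SparseCategoricalCrossentropy()
unsupervised_loss = tf.keras.losses.BinaryCrossentropy()

# Placeholder predictions (assumed shapes: batch of 1, N = 10 classes).
class_probs = tf.constant([[0.05] * 9 + [0.55]])  # softmax over N classes for a labeled x
labels = tf.constant([9])                         # the true class y
real_scores = tf.constant([[0.8]])                # D(x)  on unlabeled real examples
fake_scores = tf.constant([[0.2]])                # D(x*) on Generator output

# Supervised loss: classification error on the few labeled examples.
d_loss_supervised = supervised_loss(labels, class_probs)

# Unsupervised loss: the familiar real-vs-fake GAN loss.
d_loss_unsupervised = (unsupervised_loss(tf.ones_like(real_scores), real_scores)
                       + unsupervised_loss(tf.zeros_like(fake_scores), fake_scores))

d_loss = d_loss_supervised + d_loss_unsupervised
```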

Training objective

The GANs we have seen so far were all generative models. The goal of training them was to learn to produce realistic-looking examples and, consequently, the Generator network was of primary interest. The main purpose of the Discriminator network was to help the Generator improve the quality of the images it produced. At the end of training, we often disregarded the Discriminator and used only the fully trained Generator to produce realistic-looking synthetic data.

In contrast, in the SGAN we care primarily about the Discriminator. The goal of the training process is to make this network into a semi-supervised classifier whose accuracy is as close as possible to that of a fully supervised classifier, while using only a small fraction of labeled examples for training. The goal of the Generator is to aid this process by serving as a source of additional information (i.e., the fake data it produces) that helps the Discriminator identify the correct class for each example. At the end of training, the Generator gets discarded and we use the trained Discriminator as a classifier.
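After training, using the Discriminator as a classifier is then just ordinary inference. A minimal sketch, reusing the illustrative `discriminator_supervised` head from the architecture sketch above:

```python
import numpy as np

# Stand-in test images; in practice, use a real held-out test set.
x_test = np.random.rand(5, 28, 28, 1).astype("float32")

# The Generator is discarded; only the supervised head is kept.
class_probs = discriminator_supervised.predict(x_test)
predicted_classes = class_probs.argmax(axis=-1)
print(predicted_classes)  # one of the N class indices per test example
```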

Now that we have learned what motivated the SGAN and how it works, it is time to see the model in action by implementing one.
