Semi-Supervised Learning with DCGANs
25 Aug 2018

I recently wanted to try semi-supervised learning on a research problem. In Improved Techniques for Training GANs, the authors show how a deep convolutional generative adversarial network, originally intended for unsupervised learning, may be adapted for semi-supervised learning. It wasn’t immediately clear to me how the equations in Section 5 of the paper are equivalent to OpenAI’s implementation, so I decided to go through the derivation myself. In this post, I’ll show how the supervised and unsupervised loss functions in the implementation may be derived from the equations provided in the paper. If you’re already familiar with the background, skip ahead to the supervised loss or unsupervised loss derivation.
I’ve previously written about adversarial learning, and others have written several good posts about semi-supervised learning with deep convolutional generative adversarial networks (DCGANs), so it would be superfluous for me to cover the architecture and optimization here. Instead, I’ll provide only the background needed to understand the loss functions. Those interested in a more thorough discussion with TensorFlow code might want to read Thalles Silva’s blog post, Semi-supervised learning with Generative Adversarial Networks (GANs).
Background
Unlike supervised learning, where we have labels for all training examples, and unsupervised learning, where we have no labels at all, in semi-supervised learning we have labels for some, but not all, training examples. We should still extract whatever useful information the unlabeled examples provide, even though we don’t know their labels. Thus, training a semi-supervised DCGAN for classification consists of four parts, namely
- the discriminator must predict the correct label for those training examples with labels (exactly like supervised learning),
- the discriminator must verify that the unlabeled training examples were drawn from the original data distribution, i.e., belong to one of the known classes even if we don’t know which one,
- the discriminator must claim that samples drawn from the generator are fake, i.e., do not belong to one of the known classes, and finally,
- the generator must continually improve its fake samples to fool the discriminator into thinking that they are real.
Taken together, the idea is that the discriminator acts as a conventional multi-class classifier, while simultaneously learning to map the unlabeled training examples to the same feature distribution as the labeled ones.
Deriving the Loss Functions
To realize the four points described above with a differentiable loss function, the authors of Improved Techniques for Training GANs propose using a $(K+1)$-class classifier, where the first $K$ outputs correspond to the probabilities of the $K$ real classes. The sum over the first $K$ outputs may then be interpreted as the probability that the given example is real, and output $K+1$ as the probability that it is fake. They note, however, that such a classifier is over-parameterized: subtracting the same function $f(\mathbf{x})$ from every logit leaves the softmax output unchanged, so we may fix $l_{K+1}(\mathbf{x}) = 0$ without loss of generality. This is analogous to how a binary classifier requires only a single output neuron to produce the probabilities of both classes. Thus, the same result may be achieved using a $K$-class classifier as the discriminator.
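The over-parameterization argument can be checked numerically. The sketch below (Python with NumPy; the logits are arbitrary random values, not outputs of any trained model, and `K = 10` is an illustrative choice) takes the logits of a hypothetical $(K+1)$-class discriminator, subtracts the last logit from all of them so that $l_{K+1}(\mathbf{x}) = 0$, and confirms that the softmax probabilities, and therefore the probability of "real", are unchanged.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
K = 10  # number of real classes (illustrative)

# Hypothetical logits of an over-parameterized (K+1)-class discriminator:
# entries 0..K-1 are the real classes, the last entry is the "fake" class.
logits_k_plus_1 = rng.normal(size=K + 1)

# Subtracting the same value from every logit leaves the softmax unchanged,
# so the "fake" logit can be pinned to zero.
logits_pinned = logits_k_plus_1 - logits_k_plus_1[-1]

p_original = softmax(logits_k_plus_1)
p_pinned = softmax(logits_pinned)
print(np.allclose(p_original, p_pinned))  # True

# Probability that the example is real = sum of the first K outputs.
p_real = p_pinned[:K].sum()

# Only the first K pinned logits need to be produced by the network.
k_class_logits = logits_pinned[:K]
```

Since the last pinned logit is always zero, it carries no information, which is why a $K$-output network suffices.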
Supervised Loss
With a $K$-class classifier, the supervised loss is simply the categorical cross-entropy loss. Let $\mathbf{x} \in [-1, 1]^{H \times W \times 3}$ be a single training example (an RGB image normalized to the same range as the generator’s output) and let $\mathbf{y} \in \{0, 1\}^{K}$, with $\sum\limits_{i=1}^{K}y_i = 1$, be a one-hot vector indicating the class to which $\mathbf{x}$ belongs. The discriminator takes the input $\mathbf{x}$ and produces the $K$-dimensional vector of logits $l(\mathbf{x})$ at the final layer. The predicted probability $\hat{y}_j$ for class $j$ is computed using the softmax function

$$\hat{y}_j = \frac{e^{l_j(\mathbf{x})}}{\sum\limits_{k=1}^{K} e^{l_k(\mathbf{x})}},$$

which we substitute into the categorical cross-entropy loss

$$L_{\text{supervised}} = -\sum\limits_{j=1}^{K} y_j \log \hat{y}_j = -\sum\limits_{j=1}^{K} y_j \log \frac{e^{l_j(\mathbf{x})}}{\sum\limits_{k=1}^{K} e^{l_k(\mathbf{x})}}.$$

Because $\mathbf{y}$ is a one-hot vector, it is zero everywhere except at the index of the class to which $\mathbf{x}$ belongs. Let this index be $j'$, so that $y_i = 0$ for all $i \neq j'$ and $y_{j'} = 1$. All terms in the outer summation vanish except the one for $j = j'$, and the loss simplifies to

$$L_{\text{supervised}} = -\log \frac{e^{l_{j'}(\mathbf{x})}}{\sum\limits_{k=1}^{K} e^{l_k(\mathbf{x})}} = -l_{j'}(\mathbf{x}) + \log \sum\limits_{k=1}^{K} e^{l_k(\mathbf{x})}.$$
After taking the average over all training examples in the minibatch, this gives us the supervised loss function provided in the implementation.
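As a sanity check on the supervised derivation, here is a minimal NumPy sketch that computes $-l_{j'}(\mathbf{x}) + \log\sum_{k=1}^{K} e^{l_k(\mathbf{x})}$ averaged over a minibatch and compares it with the textbook cross-entropy of explicit softmax probabilities. It assumes logits of shape `(batch, K)` and integer class indices instead of one-hot vectors; the function names and random inputs are illustrative and not taken from the official implementation.

```python
import numpy as np

def logsumexp(logits):
    """Stable log(sum_k exp(l_k)) over the last axis."""
    m = logits.max(axis=-1, keepdims=True)
    return m[..., 0] + np.log(np.exp(logits - m).sum(axis=-1))

def supervised_loss(logits, labels):
    """Minibatch mean of -l_{j'}(x) + logsumexp(l(x)).

    logits: shape (batch, K); labels: integer class indices, shape (batch,).
    """
    rows = np.arange(logits.shape[0])
    return np.mean(-logits[rows, labels] + logsumexp(logits))

def cross_entropy_reference(logits, labels):
    """Cross-entropy from explicit softmax probabilities and one-hot labels."""
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    one_hot = np.eye(logits.shape[1])[labels]
    return np.mean(-(one_hot * np.log(probs)).sum(axis=1))

rng = np.random.default_rng(1)
logits = rng.normal(size=(8, 10))
labels = rng.integers(0, 10, size=8)

print(np.isclose(supervised_loss(logits, labels),
                 cross_entropy_reference(logits, labels)))  # True
```

Working directly with logits and a stable log-sum-exp avoids forming the softmax probabilities at all, which is one reason a loss written this way is convenient in practice.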
Unsupervised Loss
Because we’re going to average the loss over the minibatch anyway, we compute it for a single example to keep things simple. Let $\mathbf{x}_{real} \in [-1, 1]^{H \times W \times 3}$ be a single training example (whose label we don’t know) and $\mathbf{x}_{fake} \in [-1, 1]^{H \times W \times 3}$ be a single example sampled from the generator. For the unsupervised loss, the discriminator must maximize both the log-probability that $\mathbf{x}_{real}$ is real and the log-probability that $\mathbf{x}_{fake}$ is fake, i.e., minimize

$$L_{\text{unsupervised}} = -\log D(\mathbf{x}_{real}) - \log\left(1 - D(\mathbf{x}_{fake})\right),$$
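where $D(\mathbf{x})$ is the probability that $\mathbf{x}$ is real: the softmax is taken over the $K$ logits plus a dummy logit fixed to $l_{K+1}(\mathbf{x}) = 0$ (which appears as $e^0 = 1$ in the denominator and represents the “fake” class), and the resulting probabilities of the $K$ real classes are summed, i.e.,

$$D(\mathbf{x}) = \frac{Z(\mathbf{x})}{Z(\mathbf{x}) + 1},$$

where

$$Z(\mathbf{x}) = \sum\limits_{k=1}^{K} e^{l_k(\mathbf{x})}$$

is the sum of the exponentials of the logits representing the “real” classes. Substituting and simplifying, we get

$$
\begin{aligned}
L_{\text{unsupervised}} &= -\log \frac{Z(\mathbf{x}_{real})}{Z(\mathbf{x}_{real}) + 1} - \log\left(1 - \frac{Z(\mathbf{x}_{fake})}{Z(\mathbf{x}_{fake}) + 1}\right) \\
&= -\log Z(\mathbf{x}_{real}) + \log\left(Z(\mathbf{x}_{real}) + 1\right) + \log\left(Z(\mathbf{x}_{fake}) + 1\right) \\
&= -\log\sum\limits_{k=1}^{K}e^{l_k(\mathbf{x}_{real})} + \log\left(\sum\limits_{k=1}^{K}e^{l_k(\mathbf{x}_{real})} + 1\right) + \log\left(\sum\limits_{k=1}^{K}e^{l_k(\mathbf{x}_{fake})} + 1\right).
\end{aligned}
$$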
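Before continuing, a quick numerical check of $D(\mathbf{x})$ may help. The NumPy sketch below uses random placeholder logits for a single example, appends a dummy zero logit, sums the softmax mass on the $K$ real classes, and compares the result with $Z(\mathbf{x})/(Z(\mathbf{x}) + 1)$, where $Z(\mathbf{x}) = \sum_{k=1}^{K} e^{l_k(\mathbf{x})}$. It also shows that this equals the logistic sigmoid of $\log Z(\mathbf{x})$; that last form is an extra observation for intuition, not something the derivation depends on.

```python
import numpy as np

rng = np.random.default_rng(2)
K = 10
logits = rng.normal(size=K)  # l_1(x), ..., l_K(x) for one example (placeholder values)

# Softmax over the K real logits plus a dummy "fake" logit fixed at 0.
with_dummy = np.append(logits, 0.0)
probs = np.exp(with_dummy) / np.exp(with_dummy).sum()
d_softmax = probs[:K].sum()  # total probability on the K real classes

# Closed form: Z(x) / (Z(x) + 1), with Z(x) = sum_k exp(l_k(x)).
Z = np.exp(logits).sum()
d_closed = Z / (Z + 1.0)

# Equivalent form: logistic sigmoid of log Z(x), since 1 / (1 + 1/Z) = Z / (Z + 1).
d_sigmoid = 1.0 / (1.0 + np.exp(-np.log(Z)))

print(np.isclose(d_softmax, d_closed), np.isclose(d_closed, d_sigmoid))  # True True
```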
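Noting that $\log\left(\sum\limits_{k=1}^{K}e^{l_k(\mathbf{x})} + 1\right) = \text{Softplus}\left(\log\sum\limits_{k=1}^{K}e^{l_k(\mathbf{x})}\right)$, since $\text{Softplus}(z) = \log(1 + e^z)$, we get

$$L_{\text{unsupervised}} = -\log\sum\limits_{k=1}^{K}e^{l_k(\mathbf{x}_{real})} + \text{Softplus}\left(\log\sum\limits_{k=1}^{K}e^{l_k(\mathbf{x}_{real})}\right) + \text{Softplus}\left(\log\sum\limits_{k=1}^{K}e^{l_k(\mathbf{x}_{fake})}\right).$$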
After averaging each term over all examples in the minibatch and scaling by $0.5$ to account for the fact that we have effectively doubled the batch size relative to the supervised loss, we get the unsupervised loss defined in the implementation.
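Putting the pieces together, here is a NumPy sketch of the unsupervised loss in its log-sum-exp/softplus form, including the $0.5$ scaling, checked against the direct form $-\log D(\mathbf{x}_{real}) - \log(1 - D(\mathbf{x}_{fake}))$. It assumes `logits_real` and `logits_fake` are `(batch, K)` arrays of discriminator logits for unlabeled training examples and generator samples; these names, the helper functions, and the random inputs are illustrative, not the official code.

```python
import numpy as np

def logsumexp(logits):
    """Stable log(sum_k exp(l_k)) over the last axis."""
    m = logits.max(axis=-1, keepdims=True)
    return m[..., 0] + np.log(np.exp(logits - m).sum(axis=-1))

def softplus(z):
    """Stable log(1 + exp(z)), computed as logaddexp(0, z)."""
    return np.logaddexp(0.0, z)

def unsupervised_loss(logits_real, logits_fake):
    """0.5 * minibatch means of -LSE(real) + Softplus(LSE(real)) + Softplus(LSE(fake))."""
    lse_real = logsumexp(logits_real)
    lse_fake = logsumexp(logits_fake)
    return (-0.5 * np.mean(lse_real)
            + 0.5 * np.mean(softplus(lse_real))
            + 0.5 * np.mean(softplus(lse_fake)))

def unsupervised_loss_direct(logits_real, logits_fake):
    """Same loss via D = Z / (Z + 1); only safe for moderate logits (np.exp can overflow)."""
    z_real = np.exp(logits_real).sum(axis=-1)
    z_fake = np.exp(logits_fake).sum(axis=-1)
    d_real = z_real / (z_real + 1.0)
    d_fake = z_fake / (z_fake + 1.0)
    return 0.5 * np.mean(-np.log(d_real)) + 0.5 * np.mean(-np.log(1.0 - d_fake))

rng = np.random.default_rng(3)
logits_real = rng.normal(size=(8, 10))
logits_fake = rng.normal(size=(8, 10))
print(np.isclose(unsupervised_loss(logits_real, logits_fake),
                 unsupervised_loss_direct(logits_real, logits_fake)))  # True

# With very large logits, exp() would overflow in the direct form,
# but the log-sum-exp/softplus form stays finite.
big = np.full((2, 10), 1000.0)
print(np.isfinite(unsupervised_loss(big, big)))  # True
```

The two functions agree wherever the direct form is computable; the practical advantage of the rewritten form is that it never exponentiates the raw logits.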
Conclusion
Initially developed for unsupervised representation learning, deep convolutional generative adversarial networks have also been shown to be effective for semi-supervised learning. Although the supervised and unsupervised losses presented in Section 5 of Improved Techniques for Training GANs may not initially seem to resemble those in the official implementation, a careful derivation makes the equivalence clear.