Suggested Solution
Reusing the code from the repository you shared, here are some suggested modifications to train a classifier alongside your generator and discriminator (their architectures and other losses are left untouched):
import numpy as np
import cv2

from keras import backend as K
from keras.models import Sequential, Model
from keras.layers import Input, merge
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import Adagrad
def lenet_classifier_model(nb_classes):
    # Snippet by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
    # Replace with your favorite classifier...
    # Note: `in_shape` should be set to the shape of the (generated) images fed to the classifier.
    model = Sequential()
    model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(180, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(100, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes, activation='softmax', init='he_normal'))
    return model
def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
    inputs = Input((IN_CH, img_cols, img_rows))
    x_generator = generator(inputs)

    merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
    discriminator.trainable = False
    x_discriminator = discriminator(merged)

    classifier.trainable = False
    x_classifier = classifier(x_generator)

    model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])
    return model
def train(BATCH_SIZE):
    (X_train, Y_train, LABEL_train) = get_data('train')  # replace with your data here
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5
    discriminator = discriminator_model()
    generator = generator_model()
    classifier = lenet_classifier_model(6)
    generator.summary()
    discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
        generator, discriminator, classifier)
    d_optim = Adagrad(lr=0.005)
    g_optim = Adagrad(lr=0.005)
    generator.compile(loss='mse', optimizer="rmsprop")
    discriminator_and_classifier_on_generator.compile(
        loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
        optimizer="rmsprop")
    discriminator.trainable = True
    discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
    classifier.trainable = True
    classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")

    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))
        for index in range(int(X_train.shape[0] / BATCH_SIZE)):
            image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
            label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]  # replace with your data here
            generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5
                image = np.swapaxes(image, 0, 2)
                cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
                # Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")

            # Training D:
            real_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch),
                                        axis=1)
            fake_pairs = np.concatenate(
                (X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
            X = np.concatenate((real_pairs, fake_pairs))
            y = np.zeros((20, 1, 64, 64))  # [1] * BATCH_SIZE + [0] * BATCH_SIZE
            d_loss = discriminator.train_on_batch(X, y)
            print("batch %d d_loss : %f" % (index, d_loss))
            discriminator.trainable = False

            # Training C:
            c_loss = classifier.train_on_batch(image_batch, label_batch)
            print("batch %d c_loss : %f" % (index, c_loss))
            classifier.trainable = False

            # Train G:
            g_loss = discriminator_and_classifier_on_generator.train_on_batch(
                X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :],
                [image_batch, np.ones((10, 1, 64, 64)), label_batch])
            discriminator.trainable = True
            classifier.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss[1]))

            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)
Theoretical Details
I believe there are some misunderstandings regarding how conditional GANs work and what the discriminator's role is in such schemes.
Role of the Discriminator
In the min-max game that is GAN training [4], the discriminator D plays against the generator G (the network you actually care about), so that under D's scrutiny, G becomes better at outputting realistic results. For that, D is trained to tell real samples apart from samples produced by G, while G is trained to fool D by generating realistic results, i.e. results following the target distribution.

Note: in the case of conditional GANs, i.e. GANs mapping an input sample from one domain A (e.g. a real picture) to another domain B (e.g. a sketch), D is usually fed with the pairs of samples stacked together and has to discriminate "real" pairs (input sample from A + corresponding target sample from B) from "fake" pairs (input sample from A + corresponding output from G) [1, 2].

Training a conditional generator against D (as opposed to simply training G alone with only an L1/L2 loss, e.g. as in a denoising autoencoder) improves the sampling capability of G, forcing it to output crisp, realistic results instead of trying to average the target distribution.

Even though discriminators can have multiple sub-networks to cover other tasks (see the next paragraphs), D should keep at least one sub-network/output dedicated to its main task: telling real samples apart from generated ones. Asking D to also regress further semantic information (e.g. classes) may interfere with this main purpose.

Note: D's output is often not a simple scalar/boolean. It is common to have a discriminator (e.g. PatchGAN [1, 2]) return a matrix of probabilities, evaluating how realistic the patches of its input are.
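For instance, here is a minimal sketch of such a pair-based, PatchGAN-like discriminator. It uses the modern tf.keras functional API rather than the Keras 1 style of the code above, and the layer counts, filter sizes, and image shapes are made-up assumptions, not values from your repository:

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def patch_discriminator(img_shape=(64, 64, 3), sketch_shape=(64, 64, 1)):
    """PatchGAN-like D: classifies each patch of an (input, target) pair as real/fake."""
    img_a = layers.Input(shape=img_shape)      # conditioning image (domain A)
    img_b = layers.Input(shape=sketch_shape)   # real or generated target (domain B)
    x = layers.Concatenate(axis=-1)([img_a, img_b])  # stack the pair along the channel axis
    for filters in (64, 128, 256):
        x = layers.Conv2D(filters, 4, strides=2, padding='same')(x)
        x = layers.LeakyReLU(0.2)(x)
    # One probability per patch instead of a single scalar for the whole image
    patch_probs = layers.Conv2D(1, 4, padding='same', activation='sigmoid')(x)
    return Model([img_a, img_b], patch_probs, name='patch_discriminator')

# d = patch_discriminator()
# d.output_shape -> (None, 8, 8, 1), i.e. an 8x8 map of per-patch "realness" probabilities
```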
Conditional GANs
Traditional GANs are trained in an unsupervised manner to generate realistic data (e.g. images) from a random noise vector as input [4].

As previously mentioned, conditional GANs have further input conditions. Along with, or instead of, the noise vector, they take as input a sample from a domain A and return a corresponding sample from a domain B. A can be a completely different modality, e.g. B = sketch image while A = discrete label; B = volumetric data while A = RGB image; etc. [3]

Such GANs can also be conditioned on multiple inputs, e.g. A = real image + discrete label while B = sketch image. A famous work introducing such methods is InfoGAN [5]. It presents how to condition GANs on multiple continuous or discrete inputs (e.g. A = digit class + writing type, B = handwritten digit image), using a more advanced discriminator whose second task is to force G to maximize the mutual information between its conditioning inputs and its corresponding outputs.
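As an illustration, one common way (among several) to feed a discrete label into a convolutional generator alongside an image is to broadcast the one-hot label into extra feature channels concatenated to the image. The sketch below is a toy example with made-up shapes and layer sizes, again using the tf.keras functional API rather than your repository's code:

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def conditioned_generator(img_shape=(64, 64, 3), nb_classes=6):
    """Toy conditional G: the class label is broadcast as extra input channels."""
    img = layers.Input(shape=img_shape)        # sample from domain A
    label = layers.Input(shape=(nb_classes,))  # one-hot class label
    # Broadcast the label vector to a (H, W, nb_classes) map and stack it with the image
    label_map = layers.RepeatVector(img_shape[0] * img_shape[1])(label)
    label_map = layers.Reshape((img_shape[0], img_shape[1], nb_classes))(label_map)
    x = layers.Concatenate(axis=-1)([img, label_map])
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    out = layers.Conv2D(1, 3, padding='same', activation='tanh')(x)  # sample from domain B
    return Model([img, label], out, name='conditioned_generator')
```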
Maximizing the Mutual Information for cGANs
The InfoGAN discriminator has 2 heads/sub-networks to cover its 2 tasks [5]:

- One head D1 does the traditional real/generated discrimination -- G has to minimize this result, i.e. it has to fool D1 so that it cannot tell real from generated data apart;
- Another head D2 (also named the Q network) tries to regress the input A information -- G has to maximize this result, i.e. it has to output data which "show" the requested semantic information (c.f. the mutual-information maximization between G's conditional inputs and its outputs).

You can find a Keras implementation here for instance: https://github.com/eriklindernoren/Keras-GAN/tree/master/infogan.
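Not taken from that repository, but as a rough sketch of the two-head idea (a shared convolutional trunk, one real/fake output D1 and one auxiliary class output Q), with hypothetical layer sizes and tf.keras API:

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def two_head_discriminator(img_shape=(64, 64, 1), nb_classes=6):
    """Shared trunk with two heads: D1 (real/fake) and Q (class estimation)."""
    img = layers.Input(shape=img_shape)
    x = img
    for filters in (64, 128):
        x = layers.Conv2D(filters, 4, strides=2, padding='same')(x)
        x = layers.LeakyReLU(0.2)(x)
    features = layers.Flatten()(x)  # representation shared by both heads
    validity = layers.Dense(1, activation='sigmoid', name='d1')(features)            # real/fake head
    class_probs = layers.Dense(nb_classes, activation='softmax', name='q')(features)  # Q head
    return Model(img, [validity, class_probs], name='two_head_discriminator')

# d = two_head_discriminator()
# d.compile(optimizer='adam', loss=['binary_crossentropy', 'categorical_crossentropy'])
```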
Several works use similar schemes to improve control over what a GAN generates, by using the provided labels and maximizing the mutual information between these inputs and G's outputs [6, 7]. The basic idea is always the same though (see the sketch after this list):

- Train G to generate elements of domain B, given some inputs from domain(s) A;
- Train D to discriminate "real"/"fake" results -- G has to minimize this;
- Train Q (e.g. a classifier; it can share layers with D) to estimate the original A inputs from the B samples -- G has to maximize this.
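In Keras terms, this usually translates into a combined model that runs G's output through frozen copies of D and Q, so that a single `train_on_batch` call back-propagates the reconstruction, adversarial, and classification terms into G, much like `discriminator_and_classifier_on_generator` in the code above. A minimal sketch, assuming `generator`, `discriminator`, and `q_network` are placeholder models like the ones sketched earlier, and with hypothetical loss weights:

```python
from tensorflow.keras import layers, Model

def build_generator_trainer(generator, discriminator, q_network,
                            img_shape=(64, 64, 3), nb_classes=6):
    """Combined model used only to update G; D and Q are frozen inside it."""
    img = layers.Input(shape=img_shape)
    label = layers.Input(shape=(nb_classes,))
    fake_b = generator([img, label])
    discriminator.trainable = False   # only G's weights are updated through this model
    q_network.trainable = False
    validity = discriminator([img, fake_b])   # adversarial term
    pred_class = q_network(fake_b)            # mutual-information / classification term
    trainer = Model([img, label], [fake_b, validity, pred_class])
    # Hypothetical weighting: strong reconstruction, weaker adversarial/class terms
    trainer.compile(optimizer='adam',
                    loss=['mae', 'binary_crossentropy', 'categorical_crossentropy'],
                    loss_weights=[100.0, 1.0, 1.0])
    return trainer
```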
Wrapping Up
In your case, it seems you have the following training data:

- real images Ia
- corresponding sketch images Ib
- corresponding class labels c

And you want to train a generator G so that, given an image Ia and its class label c, it outputs a proper sketch image Ib'.

All in all, that's a lot of information, and you can supervise your training both on the conditioned images and the conditioned labels...

Inspired by the aforementioned methods [1, 2, 5, 6, 7], here is a possible way of using all this information to train your conditional G:
Network G:

- Inputs: Ia + c
- Output: Ib'
- Architecture: up to you (e.g. U-Net, ResNet, ...)
- Losses: L1/L2 loss between Ib' and Ib, -D loss, Q loss

Network D:

- Inputs: Ia + Ib (real pair), Ia + Ib' (fake pair)
- Output: "fakeness" scalar/matrix
- Architecture: up to you (e.g. PatchGAN)
- Loss: cross-entropy on the "fakeness" estimation

Network Q:

- Inputs: Ib (real sample, for training Q), Ib' (fake sample, when back-propagating through G)
- Output: c' (estimated class)
- Architecture: up to you (e.g. LeNet, ResNet, VGG, ...)
- Loss: cross-entropy between c and c'
Training Phase:

- Train D on a batch of real pairs Ia + Ib, then on a batch of fake pairs Ia + Ib';
- Train Q on a batch of real samples Ib;
- Fix the D and Q weights;
- Train G, passing its generated outputs Ib' to D and Q to back-propagate through them (a minimal per-batch loop is sketched below).
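For illustration, one alternating update could look like the following sketch. All model names are placeholders referring to the hypothetical models sketched earlier (a pair-based, already-compiled `discriminator`; a plain, already-compiled image classifier `q_network`, e.g. the LeNet-like model from the code at the top; and `gen_trainer` built with `build_generator_trainer`), so this is an outline of the scheme rather than your repository's exact training code:

```python
import numpy as np

def train_step(Ia, Ib, c, generator, discriminator, q_network, gen_trainer):
    """One alternating update on a single batch (Ia: real images, Ib: sketches, c: one-hot labels)."""
    batch = Ia.shape[0]
    Ib_fake = generator.predict([Ia, c])
    patch_shape = discriminator.output_shape[1:]  # e.g. (8, 8, 1) for the PatchGAN sketch

    # 1) Train D on real pairs (target ~1), then on fake pairs (target ~0)
    d_loss_real = discriminator.train_on_batch([Ia, Ib], np.ones((batch,) + patch_shape))
    d_loss_fake = discriminator.train_on_batch([Ia, Ib_fake], np.zeros((batch,) + patch_shape))

    # 2) Train Q on real samples only
    q_loss = q_network.train_on_batch(Ib, c)

    # 3) Train G through the frozen D and Q (reconstruction + adversarial + class targets)
    g_loss = gen_trainer.train_on_batch([Ia, c],
                                        [Ib, np.ones((batch,) + patch_shape), c])
    return d_loss_real, d_loss_fake, q_loss, g_loss
```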
Note: this is a really rough architecture description. I'd recommend going through the literature ([1, 5, 6, 7] as a good start) to get
more details and maybe a more elaborate solution.
References
1. Isola, Phillip, et al. "Image-to-image translation with conditional adversarial networks." arXiv preprint (2017). http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf
2. Zhu, Jun-Yan, et al. "Unpaired image-to-image translation using cycle-consistent adversarial networks." arXiv preprint arXiv:1703.10593 (2017). http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf
3. Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014). https://arxiv.org/pdf/1411.1784
4. Goodfellow, Ian, et al. "Generative adversarial nets." Advances in Neural Information Processing Systems. 2014. http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
5. Chen, Xi, et al. "InfoGAN: Interpretable representation learning by information maximizing generative adversarial nets." Advances in Neural Information Processing Systems. 2016. http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf
6. Lee, Minhyeok, and Junhee Seok. "Controllable Generative Adversarial Network." arXiv preprint arXiv:1708.00598 (2017). https://arxiv.org/pdf/1708.00598.pdf
7. Odena, Augustus, Christopher Olah, and Jonathon Shlens. "Conditional image synthesis with auxiliary classifier GANs." arXiv preprint arXiv:1610.09585 (2016). http://proceedings.mlr.press/v70/odena17a/odena17a.pdf