A highly interesting application of generative adversarial networks is image generation. One popular and challenging use case of image generation is creating anime faces. It's a great project to build while getting started with GANs. So, let's dive right in and build one ourselves!
GANs (generative adversarial networks) are a class of machine learning models that can generate new data similar to the data they were trained on.
They consist of two components called a generator and a discriminator.
| Generator | Discriminator |
| --- | --- |
| The generator creates new data samples by trying to mimic the real data. | The discriminator tries to distinguish between real and generated data. |
GANs work by training both of these components in a competitive way. The generator aims to produce data that can trick the discriminator into believing it is real, while the discriminator aims to classify real and generated data correctly.
As training progresses, the generator improves at producing realistic data, while the discriminator improves at distinguishing real data from generated data. By the end of training, the generator's outputs become highly realistic.
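For readers who want the underlying math, this competition is usually written as the standard GAN minimax objective, where G is the generator, D is the discriminator, x is a real sample, and z is random noise:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

The discriminator tries to maximize this quantity, while the generator tries to minimize it.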
We begin by importing the required libraries and modules to access their functionalities in the code.
- `os` is needed for interacting with the operating system, e.g., reading files.
- `numpy` is needed for numerical operations.
- `matplotlib` is needed for plotting and creating visualizations.
- `random` is needed for introducing randomness in data manipulation.
- `warnings` is needed to avoid cluttering the output.
- `tensorflow` is needed for all of the deep-learning aspects of the code.
```python
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import random
import warnings
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model
# array_to_img is needed later by the DCGANMonitor callback
from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy

warnings.filterwarnings('ignore')
```
For training our model, we have used the anime dataset found on Kaggle. These images have been saved in the `data` folder in the same directory.
In the following snippet, we iterate through the files within `data` and create a list called `all_images`, in which we store the path of each image file.
```python
all_images = []
for image in os.listdir("data"):
    image_path = os.path.join("data", image)
    all_images.append(image_path)
```
To check if our data has been loaded properly, we display a few images from the `all_images` list. In this case, we show six of the images using Matplotlib's `plt.subplots()`.
```python
selected_images = random.sample(all_images, 6)
fig, ax = plt.subplots(1, 6, figsize=(35, 35))
for i, image_path in enumerate(selected_images):
    img = mpimg.imread(image_path)
    ax[i].imshow(img)
    ax[i].axis('off')
plt.show()
```
The output will contain six images from the dataset, for instance:
We save the training image data in `train_images` by converting the images we just loaded into arrays. These images feed the relevant information to the model in order to train it.
In the following snippet, we save the image data, normalize the pixel values to the range [-1, 1], and reshape the images into a consistent shape.
```python
train_images = [img_to_array(load_img(path)) for path in all_images]
train_images = np.array(train_images, dtype='float32')
train_images = (train_images - 127.5) / 127.5   # normalize pixel values to [-1, 1]
train_images = train_images.reshape(train_images.shape[0], 64, 64, 3)
```
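As a quick, optional sanity check (assuming the Kaggle images are 64 × 64 RGB), we can inspect the shape and value range of the resulting array:

```python
# Expect (num_images, 64, 64, 3), with pixel values between -1.0 and 1.0
print(train_images.shape)
print(train_images.min(), train_images.max())
```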
For our GAN model, we now create the generator.
Read about it in detail in the generator for AI anime answer.
```python
generator = Sequential(name='generator')
generator.add(layers.Dense(8 * 8 * 512, input_dim=100))
generator.add(layers.ReLU())
generator.add(layers.Reshape((8, 8, 512)))
generator.add(layers.Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same',
                                     kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)))
generator.add(layers.ReLU())
generator.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same',
                                     kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)))
generator.add(layers.ReLU())
generator.add(layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same',
                                     kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)))
generator.add(layers.ReLU())
generator.add(layers.Conv2D(3, (4, 4), padding='same', activation='tanh'))
generator.summary()
```
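As an optional check (not part of the original code), we can pass a single random noise vector through the untrained generator to confirm it produces a 64 × 64 RGB image:

```python
# The generator maps a 100-dimensional noise vector to a 64 x 64 RGB image
sample_noise = tf.random.normal([1, 100])
sample_img = generator(sample_noise, training=False)
print(sample_img.shape)  # (1, 64, 64, 3), values in [-1, 1] because of the tanh output
```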
To judge whether the generated images are realistic, we now build the discriminator.
Read about it in detail in the discriminator for AI anime answer.
```python
discriminator = Sequential(name='discriminator')
input_shape = (64, 64, 3)
discriminator.add(layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same', input_shape=input_shape))
discriminator.add(layers.BatchNormalization())
discriminator.add(layers.LeakyReLU(alpha=0.2))
discriminator.add(layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same'))
discriminator.add(layers.BatchNormalization())
discriminator.add(layers.LeakyReLU(alpha=0.2))
discriminator.add(layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same'))
discriminator.add(layers.BatchNormalization())
discriminator.add(layers.LeakyReLU(alpha=0.2))
discriminator.add(layers.Flatten())
discriminator.add(layers.Dropout(0.3))
discriminator.add(layers.Dense(1, activation='sigmoid'))
discriminator.summary()
```
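Likewise, as an optional check, we can verify that the discriminator maps a generated image to a single probability:

```python
# The discriminator outputs one value in [0, 1] per image (real vs. fake)
test_img = generator(tf.random.normal([1, 100]), training=False)
score = discriminator(test_img, training=False)
print(score.shape)  # (1, 1)
```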
We define our custom DCGAN class inheriting from `keras.Model`. This class implements the adversarial training process between the generator and the discriminator.
The `__init__` constructor initializes the DCGAN class with the generator, discriminator, and latent dimension. Let's go through the important attributes here:

- `self.generator` and `self.discriminator`: store the generator and discriminator models.
- `self.latent_dim`: holds the dimension of the random noise vector fed to the generator.
- `self.discriminator_loss_metric` and `self.generator_loss_metric`: metrics to track the two losses.

The `metrics` property returns the loss metrics for monitoring.

The `compile` method configures the optimizers and the loss function for training.

The `train_step` method defines a single training step for the DCGAN. It computes the batch size, generates random noise, and uses gradient tapes for automatic differentiation.
Inside the first gradient tape context, we calculate the following:

- The discriminator loss for real images, using labels with slight noise added.
- The discriminator loss for fake images.
- The combined discriminator loss as the average of the two.

We then compute the gradients of the discriminator loss and apply them to update the discriminator's variables. Inside the second gradient tape context, we compute the generator loss on the fake images, take its gradients, and update the generator's variables.
Next, we update the discriminator and generator loss metrics.
We finally return the updated loss metrics for monitoring.
```python
class DCGAN(keras.Model):
    def __init__(self, generator, discriminator, latent_dim):
        super(DCGAN, self).__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        self.discriminator_loss_metric = keras.metrics.Mean(name='discriminator_loss')
        self.generator_loss_metric = keras.metrics.Mean(name='generator_loss')

    @property
    def metrics(self):
        return [self.discriminator_loss_metric, self.generator_loss_metric]

    def compile(self, g_opt, d_opt, loss_fn):
        super(DCGAN, self).compile()
        self.g_opt = g_opt
        self.d_opt = d_opt
        self.loss_fn = loss_fn

    def train_step(self, real_imgs):
        batch_size = tf.shape(real_imgs)[0]
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Train the discriminator on real and generated images
        with tf.GradientTape() as tape:
            pred_real = self.discriminator(real_imgs, training=True)
            real_labels = tf.ones((batch_size, 1)) + 0.05 * tf.random.uniform((batch_size, 1))
            discriminator_loss_real = self.loss_fn(real_labels, pred_real)

            fake_imgs = self.generator(noise, training=True)
            pred_fake = self.discriminator(fake_imgs, training=True)
            fake_labels = tf.zeros((batch_size, 1))
            discriminator_loss_fake = self.loss_fn(fake_labels, pred_fake)

            discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2

        gradients = tape.gradient(discriminator_loss, self.discriminator.trainable_variables)
        self.d_opt.apply_gradients(zip(gradients, self.discriminator.trainable_variables))

        # Train the generator to fool the discriminator
        labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as tape:
            fake_imgs = self.generator(noise, training=True)
            pred_fake = self.discriminator(fake_imgs, training=True)
            generator_loss = self.loss_fn(labels, pred_fake)

        gradients = tape.gradient(generator_loss, self.generator.trainable_variables)
        self.g_opt.apply_gradients(zip(gradients, self.generator.trainable_variables))

        self.discriminator_loss_metric.update_state(discriminator_loss)
        self.generator_loss_metric.update_state(generator_loss)

        return {'discriminator_loss': self.discriminator_loss_metric.result(),
                'generator_loss': self.generator_loss_metric.result()}
```
The DCGANMonitor class inherits from `keras.callbacks.Callback`. It initializes the callback with the number of images to generate (`num_imgs`) and the dimension of the latent space (`latent_dim`). `self.noise` holds the random noise used for image generation.
The `on_epoch_end` method gets the generated images from the generator, scales them back to pixel values, and displays them in a 5 x 5 grid.
The `on_train_end` method saves the trained generator and discriminator as `generator.h5` and `discriminator.h5` so that we can use them again later without spending a lot of time on training.
```python
class DCGANMonitor(keras.callbacks.Callback):
    def __init__(self, num_imgs=25, latent_dim=100):
        self.num_imgs = num_imgs
        self.latent_dim = latent_dim
        self.noise = tf.random.normal([25, latent_dim])

    def on_epoch_end(self, epoch, logs=None):
        # Generate images from the fixed noise and rescale them from [-1, 1] back to [0, 255]
        gen_imgs = self.model.generator(self.noise)
        gen_imgs = (gen_imgs * 127.5) + 127.5

        fig = plt.figure(figsize=(8, 8))
        for i in range(self.num_imgs):
            plt.subplot(5, 5, i + 1)
            img = array_to_img(gen_imgs[i])
            plt.imshow(img)
            plt.axis('off')
        plt.show()

    def on_train_end(self, logs=None):
        self.model.generator.save('generator.h5')
        self.model.discriminator.save('discriminator.h5')

dcgan = DCGAN(generator, discriminator, 100)
```
We finally compile our `dcgan` and specify the `learning_rate`, `beta_1`, and `loss_fn` values.
```python
dcgan.compile(g_opt=Adam(learning_rate=0.0003, beta_1=0.5),
              d_opt=Adam(learning_rate=0.0001, beta_1=0.5),
              loss_fn=BinaryCrossentropy())
```
The number of epochs (`N_EPOCHS`) we train for here is 55; this value is subject to change and should be chosen to avoid both underfitting and overfitting.
```python
N_EPOCHS = 55
dcgan.fit(train_images, epochs=N_EPOCHS, callbacks=[DCGANMonitor()])
```
Note: Since the dataset we're using is quite large, this step can take a few hours to complete. Therefore, we'll be demonstrating our application by loading our saved model.
As the epoch number increases, our model starts generating more realistic images. Let's take a look at a few random epoch outputs.
After completing all the epochs (55 in our case), our model can generate anime character images like the ones below.
Note: Not every image generated by our GAN will be as realistic as the real dataset, and some faces may have unexpected artifacts such as eyes of different colors or out-of-proportion features.
Congratulations, our anime GAN has now been trained to create realistic, never-before-seen anime character images! That is super cool and a prime use case for GANs. Let's see how we can load our generator and discriminator and use them to generate images.
The following code simply recreates the DCGAN by loading the `generator` and `discriminator` and compiling them again; the `create_dcgan` function is used to achieve this. We then use Python's Flask library to serve an HTML page that displays the generated anime character image along with a "Generate" button. Each time we click this button, a new image is generated for us!
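Since that part of the code isn't shown above, here is a minimal sketch of how it could look. The body of `create_dcgan`, the `templates/index.html` file, the `static/generated.png` path, and the route setup are illustrative assumptions rather than the original implementation, and the sketch relies on the `DCGAN` class defined earlier:

```python
# A minimal sketch of reloading the saved models and serving generated images with Flask.
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.preprocessing.image import array_to_img
from flask import Flask, render_template

def create_dcgan():
    # Recreate the DCGAN from the saved generator and discriminator
    generator = load_model('generator.h5')
    discriminator = load_model('discriminator.h5')
    dcgan = DCGAN(generator, discriminator, 100)
    dcgan.compile(g_opt=Adam(learning_rate=0.0003, beta_1=0.5),
                  d_opt=Adam(learning_rate=0.0001, beta_1=0.5),
                  loss_fn=BinaryCrossentropy())
    return dcgan

dcgan = create_dcgan()
app = Flask(__name__)

@app.route('/')
def generate():
    # Sample fresh noise, generate one image, and rescale it from [-1, 1] to [0, 255]
    noise = tf.random.normal([1, 100])
    gen_img = dcgan.generator(noise, training=False)
    img = array_to_img((gen_img[0] * 127.5) + 127.5)
    img.save('static/generated.png')  # the assumed template displays this file
    # index.html (assumed) shows the image and a "Generate" button that reloads this route
    return render_template('index.html')

if __name__ == '__main__':
    app.run(debug=True)
```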
We can keep clicking on generate until we find a picture of our choice.
It is recommended that you set up a Jupyter notebook and execute the code blocks one by one. After your model has been trained and the `.h5` files are saved, you can load them anytime to create new images.