AI anime generation using GANs

A highly interesting application of generative adversarial networks is image generation. One popular and challenging use case of image generation is creating anime faces. It's a great project to build while getting started with GANs. So, let's dive straight in and build one ourselves!

Generative adversarial networks

GAN stands for generative adversarial network, a class of machine learning models that can generate new data similar to the data they were trained on.

They consist of two components: a generator and a discriminator.

  • Generator: Creates new data samples by trying to mimic the real data.

  • Discriminator: Tries to distinguish between real and generated data.

GANs work by training both of these components in a competitive way. The generator aims to produce data that tricks the discriminator into believing it is real, while the discriminator aims to classify real and generated data correctly.

As training progresses, the generator gets better at producing realistic data, while the discriminator gets better at telling real data apart from generated data. By the end, the generator's outputs closely resemble the real data.
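This adversarial setup is commonly summarized by the standard GAN minimax objective, in which the discriminator D tries to maximize the value function while the generator G tries to minimize it:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]

In practice, and in the code below, this objective is implemented with a binary cross-entropy loss plus a few practical tweaks such as label smoothing.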

GAN architecture diagram

Imports

  • We begin by importing the required libraries and modules to access their functionalities in the code.

    • os is needed for interacting with the operating system, e.g., listing the files in a directory and building file paths.

    • numpy is needed for numerical operations.

    • matplotlib is needed for plotting and creating visualizations.

    • random is needed for picking random samples from the data, e.g., selecting the images we display below.

    • warnings is needed to suppress warning messages and avoid cluttering the output.

    • tensorflow is needed for all of the deep-learning aspects of the code.

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import random
import warnings
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
warnings.filterwarnings('ignore')

Image loading and display

  • For training our model, we have used the anime dataset found on Kaggle. These images have been saved in a folder named data in the same directory.

  • In the following snippet, we iterate through the files within data and store the path of each image file in a list called all_images.

all_images = []
for image in os.listdir("data"):
    image_path = os.path.join("data", image)
    all_images.append(image_path)

  • To check if our data has been loaded properly, we display a few images from the all_images list. In this case, we show six of the images using Matplotlib's plt.subplots().

selected_images = random.sample(all_images, 6)
fig, ax = plt.subplots(1, 6, figsize=(35, 35))
for i, image_path in enumerate(selected_images):
    img = mpimg.imread(image_path)
    ax[i].imshow(img)
    ax[i].axis('off')
plt.show()

The output will contain six images from the dataset, for instance:

Some anime images from the dataset

Training image processing

  • We save the training image data in train_images by converting the images we loaded above into arrays. These images are used to feed the relevant information to the model in order to train it.

  • In the following snippet, we save the image data, normalize the pixel values, and reshape the images into a consistent shape.

train_images = [img_to_array(load_img(path)) for path in all_images]
train_images = np.array(train_images, dtype='float32')
# Scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
train_images = (train_images - 127.5) / 127.5
# The dataset images are 64 x 64 RGB, so each sample is reshaped to (64, 64, 3)
train_images = train_images.reshape(train_images.shape[0], 64, 64, 3)

Generator

  • The generator takes a 100-dimensional noise vector, projects it to an 8 x 8 x 512 tensor with a dense layer, and upsamples it through three Conv2DTranspose layers (8 x 8 to 16 x 16 to 32 x 32 to 64 x 64) to produce a 64 x 64 x 3 image. The final Conv2D layer uses a tanh activation so the output pixel values lie in [-1, 1], matching our normalized training images.

generator = Sequential(name='generator')
generator.add(layers.Dense(8 * 8 * 512, input_dim=100))
generator.add(layers.ReLU())
generator.add(layers.Reshape((8, 8, 512)))
generator.add(layers.Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same', kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)))
generator.add(layers.ReLU())
generator.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)))
generator.add(layers.ReLU())
generator.add(layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same', kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)))
generator.add(layers.ReLU())
generator.add(layers.Conv2D(3, (4, 4), padding='same', activation='tanh'))
generator.summary()
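
As a quick sanity check (not part of the original walkthrough), we can feed a single random noise vector through the untrained generator and confirm that it produces an image-shaped tensor:

noise = tf.random.normal([1, 100])
sample = generator(noise, training=False)
# Expected shape: (1, 64, 64, 3), with values in [-1, 1] because of the tanh output
print(sample.shape)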

Discriminator

  • To judge whether the images being created are realistic, we will now build the discriminator.

  • Read about it in detail in the discriminator for AI anime answer.

discriminator = Sequential(name='discriminator')
input_shape = (64, 64, 3)
discriminator.add(layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same', input_shape=input_shape))
discriminator.add(layers.BatchNormalization())
discriminator.add(layers.LeakyReLU(alpha=0.2))
discriminator.add(layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same'))
discriminator.add(layers.BatchNormalization())
discriminator.add(layers.LeakyReLU(alpha=0.2))
discriminator.add(layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same'))
discriminator.add(layers.BatchNormalization())
discriminator.add(layers.LeakyReLU(alpha=0.2))
discriminator.add(layers.Flatten())
discriminator.add(layers.Dropout(0.3))
discriminator.add(layers.Dense(1, activation='sigmoid'))
discriminator.summary()
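
Similarly, as a quick check (again, not part of the original walkthrough), the discriminator should map a batch of 64 x 64 x 3 images to one probability per image:

fake = generator(tf.random.normal([1, 100]), training=False)
score = discriminator(fake, training=False)
# Expected shape: (1, 1), a sigmoid probability that the image is real
print(score.shape)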

DCGAN class

  • We define our custom DCGAN class inheriting from keras.Model. This class implements the adversarial training loop between the generator and the discriminator.

  • The __init__ constructor initializes the DCGAN class with the generator, discriminator, and latent dimension. Let's go through the important attributes here.

    • self.generator and self.discriminator: Store generator and discriminator models.

    • self.latent_dim: Holds the dimension of random noise for the generator.

    • self.discriminator_loss_metric and self.generator_loss_metric: Metrics to track losses.

  • The metrics property returns loss metrics for monitoring.

  • The compile method configures the optimizer and loss function for training.

  • The train_step method defines a single training step for the DCGAN. It computes the batch size, generates random noise, and uses tf.GradientTape for automatic differentiation.

  • Inside the gradient tape contexts, we perform the following steps:

    • Compute the discriminator loss on real images, using labels with slight noise (label smoothing).

    • Compute the discriminator loss on fake images produced by the generator.

    • Combine the two into a single discriminator loss by taking their average.

    • Compute the gradients of the discriminator loss and update the discriminator's variables.

    • Compute the generator loss from the discriminator's predictions on fake images.

    • Compute the gradients of the generator loss and update the generator's variables.

  • Next, we update the discriminator and generator loss metrics.

  • We finally return the updated loss metrics for monitoring.

class DCGAN(keras.Model):
    def __init__(self, generator, discriminator, latent_dim):
        super(DCGAN, self).__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        self.discriminator_loss_metric = keras.metrics.Mean(name='discriminator_loss')
        self.generator_loss_metric = keras.metrics.Mean(name='generator_loss')

    @property
    def metrics(self):
        return [self.discriminator_loss_metric, self.generator_loss_metric]

    def compile(self, g_opt, d_opt, loss_fn):
        super(DCGAN, self).compile()
        self.g_opt = g_opt
        self.d_opt = d_opt
        self.loss_fn = loss_fn

    def train_step(self, real_imgs):
        batch_size = tf.shape(real_imgs)[0]
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Train the discriminator on real and generated images
        with tf.GradientTape() as tape:
            pred_real = self.discriminator(real_imgs, training=True)
            # Slightly noisy "real" labels (label smoothing)
            real_labels = tf.ones((batch_size, 1)) + 0.05 * tf.random.uniform((batch_size, 1))
            discriminator_loss_real = self.loss_fn(real_labels, pred_real)
            fake_imgs = self.generator(noise, training=True)
            pred_fake = self.discriminator(fake_imgs, training=True)
            fake_labels = tf.zeros((batch_size, 1))
            discriminator_loss_fake = self.loss_fn(fake_labels, pred_fake)
            discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2
        gradients = tape.gradient(discriminator_loss, self.discriminator.trainable_variables)
        self.d_opt.apply_gradients(zip(gradients, self.discriminator.trainable_variables))

        # Train the generator to fool the discriminator (labels are all ones)
        labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as tape:
            fake_imgs = self.generator(noise, training=True)
            pred_fake = self.discriminator(fake_imgs, training=True)
            generator_loss = self.loss_fn(labels, pred_fake)
        gradients = tape.gradient(generator_loss, self.generator.trainable_variables)
        self.g_opt.apply_gradients(zip(gradients, self.generator.trainable_variables))

        self.discriminator_loss_metric.update_state(discriminator_loss)
        self.generator_loss_metric.update_state(generator_loss)
        return {'discriminator_loss': self.discriminator_loss_metric.result(),
                'generator_loss': self.generator_loss_metric.result()}

DCGANMonitor class

  • The DCGANMonitor class inherits from keras.callbacks.Callback. It initializes the callback with the number of images to generate (num_imgs) and the dimension of the latent space (latent_dim). self.noise holds the fixed random noise used for image generation.

  • The on_epoch_end method generates images from this noise, rescales them back to the [0, 255] pixel range, and displays them in a 5 x 5 grid.

  • The on_train_end method saves the trained generator and discriminator as "generator.h5" and "discriminator.h5" so that we can use them again later on without spending a lot of time on training.

class DCGANMonitor(keras.callbacks.Callback):
    def __init__(self, num_imgs=25, latent_dim=100):
        super(DCGANMonitor, self).__init__()
        self.num_imgs = num_imgs
        self.latent_dim = latent_dim
        # Fixed noise so we can watch the same samples improve across epochs
        self.noise = tf.random.normal([num_imgs, latent_dim])

    def on_epoch_end(self, epoch, logs=None):
        gen_imgs = self.model.generator(self.noise)
        # Rescale from [-1, 1] back to [0, 255]
        gen_imgs = (gen_imgs * 127.5) + 127.5
        fig = plt.figure(figsize=(8, 8))
        for i in range(self.num_imgs):
            plt.subplot(5, 5, i+1)
            img = array_to_img(gen_imgs[i])
            plt.imshow(img)
            plt.axis('off')
        plt.show()

    def on_train_end(self, logs=None):
        self.model.generator.save('generator.h5')
        self.model.discriminator.save('discriminator.h5')

dcgan = DCGAN(generator, discriminator, 100)

Model compilation

  • We finally compile our dcgan and specify the learning_rate, beta_1, and loss_fn values.

dcgan.compile(
    g_opt=Adam(learning_rate=0.0003, beta_1=0.5),
    d_opt=Adam(learning_rate=0.0001, beta_1=0.5),
    loss_fn=BinaryCrossentropy()
)

Model fitting

  • We train for 55 epochs (N_EPOCHS). This value can be adjusted; it should be large enough to avoid underfitting but not so large that the model overfits.

N_EPOCHS = 55
dcgan.fit(train_images, epochs=N_EPOCHS, callbacks=[DCGANMonitor()])

Note: Since the dataset we're using is quite large, this step can take a few hours to complete. Therefore, we'll be demonstrating our application by loading our saved model.

A few epoch outputs

As the number of epochs increases, our model starts generating more realistic images. Let's take a look at a few random epoch outputs.

Sample outputs from a few different epochs during training

After completing all the epochs (55 in our case), our model can generate anime character images like the ones below.

Anime faces output

Anime faces generated by our model

Note: Not every image generated by our GAN will look as convincing as the real dataset, and some faces may show artifacts such as mismatched eye colors or out-of-proportion features.

Flask application

Congratulations, our anime GAN has now been trained to create realistic, never-before-seen anime character images! That is super cool and a prime use case for GANs. Let's see how we can load our generator and discriminator and use them to generate images in a simple Flask application. The page served by the application is styled with the following CSS, which centers the generated image inside a simple card:

body {
    font-family: Arial, sans-serif;
    margin: 0;
    padding: 0;
    background-color: #f8f9fa;
    display: flex;
    align-items: center;
    justify-content: center;
    height: 100vh;
}

.container {
    text-align: center;
    padding: 20px;
    background-color: white;
    border-radius: 10px;
    box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}

img {
    max-width: 100%;
    border-radius: 5px;
    margin-top: 20px;
}

h1 {
    margin-top: 20px;
    color: #333;
}

p {
    color: #666;
}

The following code recreates the DCGAN by loading the saved generator and discriminator and compiling them again; the create_dcgan function is used to achieve this. We then use Python's Flask library to render an HTML page on our server that displays the generated anime character image along with a "Generate" button. Each time we click this button, a new image is generated for us!
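
Since the full application code isn't reproduced here, below is a minimal sketch of what such a Flask app could look like. It loads only the saved generator for inference (rather than rebuilding the full DCGAN via create_dcgan as described above), and the route, the template name (index.html), and the output path are assumptions:

import tensorflow as tf
from flask import Flask, render_template
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import array_to_img

app = Flask(__name__)

# Load the generator saved by DCGANMonitor.on_train_end()
generator = load_model('generator.h5')

@app.route('/')
def generate():
    # Sample a latent vector and run it through the generator
    noise = tf.random.normal([1, 100])
    gen_img = generator(noise, training=False)
    # Rescale from [-1, 1] to [0, 255] and save the image for the template (assumed path)
    img = array_to_img((gen_img[0] * 127.5) + 127.5)
    img.save('static/generated.png')
    return render_template('index.html', image_file='generated.png')

if __name__ == '__main__':
    app.run(debug=True)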

Application output

We can keep clicking the "Generate" button until we find a picture we like.

Flask output

Project execution

It is recommended that you set up a Jupyter notebook and execute the code blocks one by one. After your model has been trained and the .h5 files are saved, you can load them anytime to create new images, as shown in the sketch below.
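
As a rough illustration (assuming the training code above has already been run and generator.h5 exists in the working directory), loading the saved generator and sampling a new face could look like this:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import array_to_img

# Load the generator saved at the end of training
saved_generator = load_model('generator.h5')

# Sample one latent vector and generate a new anime face
noise = tf.random.normal([1, 100])
generated = saved_generator(noise, training=False)

# Rescale from [-1, 1] back to [0, 255] for display
plt.imshow(array_to_img((generated[0] * 127.5) + 127.5))
plt.axis('off')
plt.show()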

Test your knowledge of anime generation!

Question

What two main components are required while building a GAN model for image generation?

Answer: A generator and a discriminator.
