Get Started with GANs for Image-to-Image Translation
An image domain is a set of images with similar characteristics. For example, an image domain can be a group of images acquired under certain lighting conditions or images with a common set of noise distortions.
Image-to-image translation is the task of transferring styles and characteristics from one image domain to another. The source domain is the domain of the starting image. The target domain is the desired domain after translation. The table lists three sample applications of image-to-image translation along with their source and target domains.
|Application||Source Domain||Target Domain|
|Day-to-dusk style conversion||Images acquired in the daytime||Images acquired at dusk|
|Image denoising||Images with noise distortion||Images without visible noise|
|Super-resolution||Low-resolution images||High-resolution images|
Select a GAN
You can perform image-to-image translation using deep learning generative adversarial networks (GANs). A GAN consists of a generator network and one or more discriminator networks that are trained simultaneously in competition with each other. The objective of the generator network is to generate realistic images in the target domain that the discriminator cannot distinguish from real images in that domain. The objective of the discriminator networks is to correctly classify real training images as real and generator-synthesized images as fake.
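A minimal sketch of these two competing objectives, written in plain Python rather than MATLAB: each discriminator output is assumed to be the probability that an image is real, and both losses use binary cross-entropy. The non-saturating generator loss shown here is one common choice and an assumption of this sketch, not necessarily the loss used by any particular GAN in this topic. The helper names are hypothetical.

```python
import math

def _mean(values):
    values = list(values)
    return sum(values) / len(values)

def bce(prob, label, eps=1e-12):
    """Binary cross-entropy between a predicted probability and a 0/1 label."""
    prob = min(max(prob, eps), 1.0 - eps)  # clamp so log() never sees 0
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))

def discriminator_loss(d_real, d_fake):
    """Low when real images score near 1 and generated images score near 0."""
    return _mean(bce(p, 1) for p in d_real) + _mean(bce(p, 0) for p in d_fake)

def generator_loss(d_fake):
    """Non-saturating generator loss: low when generated images score near 1 ("real")."""
    return _mean(bce(p, 1) for p in d_fake)

# A discriminator that separates real from fake well has a low loss,
# while the generator it is catching has a high loss.
print(discriminator_loss([0.9, 0.8], [0.1, 0.2]))  # about 0.33
print(generator_loss([0.1, 0.2]))                  # about 1.96
```

Because the two losses pull in opposite directions on the same discriminator scores, improving one network makes the other's task harder, which is the adversarial dynamic described above.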
The type of GAN to use depends on the training data, specifically on whether the images in the source and target domains are paired.
Supervised GANs have a one-to-one mapping between images in the source and target domains. For an example, see Generate Image from Segmentation Map Using Deep Learning (Computer Vision Toolbox). In this example, the source domain consists of categorical images representing semantic segmentation maps of street scenes. The target domain consists of images of street scenes. The data set provides a ground truth segmentation map for every training image, so every source image has a corresponding target image.
Unsupervised GANs do not have a one-to-one mapping between images in the source and target domains. For an example, see Unsupervised Day-to-Dusk Image Translation Using UNIT. In this example, the source and target domains consist of images captured in daytime and dusk conditions, respectively. However, the scene content of the daytime and dusk images differs, so the daytime images do not have a corresponding dusk image with identical scene content.
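The difference between the two cases shows up in how training data is sampled. In this plain-Python sketch (hypothetical helper names; the file names are placeholders), paired data is shuffled with a single index so that each source image stays with its target image, whereas unpaired data draws a batch from each domain independently.

```python
import random

def paired_batches(sources, targets, batch_size):
    """Supervised (paired) data: sources[i] and targets[i] show the same scene."""
    if len(sources) != len(targets):
        raise ValueError("paired data needs exactly one target per source")
    order = list(range(len(sources)))
    random.shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        # The same index selects both images, preserving the one-to-one mapping.
        yield [sources[i] for i in idx], [targets[i] for i in idx]

def unpaired_batches(domain_a, domain_b, batch_size, num_batches):
    """Unsupervised (unpaired) data: each domain is sampled independently."""
    for _ in range(num_batches):
        yield random.sample(domain_a, batch_size), random.sample(domain_b, batch_size)

maps = ["map_%d.png" % i for i in range(4)]      # hypothetical source images
scenes = ["scene_%d.png" % i for i in range(4)]  # matching target images
for src, tgt in paired_batches(maps, scenes, batch_size=2):
    print(src, tgt)  # the numeric suffixes always match within a pair
```

The unpaired sampler never looks at indices across domains, which is why an unsupervised GAN needs a different training signal, such as cycle consistency, instead of a direct comparison with a ground truth target.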
Create GAN Networks
Image Processing Toolbox™ offers functions that enable you to create popular GAN networks. You can optionally modify the networks by changing properties such as the number of downsampling operations and the type of activation and normalization. The table describes the functions that enable you to create and modify GAN networks.
|Network||Description||Creation and Modification Functions|
|pix2pixHD generator network||A pix2pixHD GAN performs supervised learning. The network consists of a single generator and a single discriminator.||Create a pix2pixHD generator network|
|CycleGAN generator network||A CycleGAN network performs unsupervised learning. The network consists of two generators and two discriminators. The first generator takes images from domain A and generates images in domain B. The corresponding discriminator takes images generated by the first generator and real images in domain B, and attempts to classify the generated images as fake and the real images as real. Conversely, the second generator takes images from domain B and generates images in domain A. The corresponding discriminator takes images generated by the second generator and real images in domain A, and attempts to classify the generated images as fake and the real images as real.||Create a CycleGAN generator network using the|
|UNIT generator network||An unsupervised image-to-image translation (UNIT) GAN performs unsupervised learning. The network consists of one generator and two discriminators. The generator takes images from both domains, A and B, and returns four output images: two translated images (A-to-B and B-to-A) and two self-reconstructed images (A-to-A and B-to-B). The first discriminator takes real and generated images from domain A and returns the likelihood that each image is real. Similarly, the second discriminator takes real and generated images from domain B and returns the likelihood that each image is real.||Create a UNIT generator network using the|
|PatchGAN discriminator network||A PatchGAN discriminator network can serve as the discriminator network for pix2pixHD, CycleGAN, and UNIT GANs, as well as for custom GANs.||Create a PatchGAN discriminator network using the|
You can also use the
Some networks require additional modification beyond the options available in the network creation functions. For example, you may want to replace the addition layers with depth concatenation layers, or you may want the initial leaky ReLU layer of a UNIT network to have a scale factor other than 0.2. To refine an existing GAN network, you can use Deep Network Designer (Deep Learning Toolbox). For more information, see Build Networks with Deep Network Designer (Deep Learning Toolbox).
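To make the CycleGAN data flow described in the table concrete, this plain-Python sketch routes toy "images" (strings standing in for image arrays) through the two generators and collects the real and generated images that each discriminator must classify. The function and key names are hypothetical and do not correspond to a toolbox API.

```python
def cyclegan_discriminator_inputs(real_a, real_b, gen_ab, gen_ba):
    """Collect the real and generated images that each CycleGAN discriminator sees.

    gen_ab translates domain A to domain B; gen_ba translates domain B to domain A.
    """
    fake_b = [gen_ab(image) for image in real_a]  # first generator: A -> B
    fake_a = [gen_ba(image) for image in real_b]  # second generator: B -> A
    return {
        # Domain B discriminator: should label real_b as real and fake_b as fake.
        "D_B": {"real": list(real_b), "fake": fake_b},
        # Domain A discriminator: should label real_a as real and fake_a as fake.
        "D_A": {"real": list(real_a), "fake": fake_a},
    }

def to_dusk(image):
    return image.replace("day", "dusk")  # toy stand-in for the A-to-B generator

def to_day(image):
    return image.replace("dusk", "day")  # toy stand-in for the B-to-A generator

inputs = cyclegan_discriminator_inputs(["day_1"], ["dusk_7"], to_dusk, to_day)
print(inputs["D_B"])  # {'real': ['dusk_7'], 'fake': ['dusk_1']}
print(inputs["D_A"])  # {'real': ['day_1'], 'fake': ['day_7']}
```

Each discriminator only ever sees images from a single domain, so it judges realism within that domain, while the generators are responsible for crossing between domains.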
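A PatchGAN discriminator classifies image patches rather than the whole image, so its output is a map of scores, and the patch size is the receptive field of one output unit. This plain-Python sketch computes that receptive field and the size of the output score map for a stack of convolution layers. The five-layer configuration in the usage lines (4-by-4 kernels with strides 2, 2, 2, 1, 1 and padding 1) is an assumption borrowed from the widely used 70-by-70 pix2pix PatchGAN, not necessarily the configuration that the toolbox creates by default.

```python
def receptive_field(layers):
    """Side length, in input pixels, seen by one unit of the final output map.

    layers: (kernel_size, stride) pairs ordered from the input to the output.
    """
    rf = 1
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel  # grow the field backward through each layer
    return rf

def output_map_size(input_size, layers, padding=1):
    """Side length of the score map, assuming the same zero padding on every layer."""
    size = input_size
    for kernel, stride in layers:
        size = (size + 2 * padding - kernel) // stride + 1
    return size

patchgan_70 = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan_70))       # 70
print(output_map_size(256, patchgan_70))  # 30
```

Under these assumptions, each of the 30-by-30 scores judges one overlapping 70-by-70 region of a 256-by-256 input, which is why an adversarial loss computed on such a discriminator is patchwise.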
If you need a network that is not available through the built-in creation functions, then you can create custom GAN networks from modular components. First, create the encoder and decoder modules, then combine the modules using the encoderDecoderNetwork function. You can optionally include a bridge connection, skip connections, or additional layers at the end of the network. For more information, see Create Modular Neural Networks.
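Before assembling modules, it can help to check that encoder outputs and decoder inputs line up. This plain-Python sketch traces tensor sizes through a hypothetical U-Net-style encoder-decoder with optional skip connections. The filter counts, the halving and doubling of height and width per stage, and the use of depth concatenation for skip connections are illustrative assumptions, not the defaults of the encoderDecoderNetwork function.

```python
def encoder_decoder_shapes(height, width, in_channels, out_channels,
                           base_filters=16, depth=2, skip_connections=True):
    """Trace (height, width, channels) through a hypothetical U-Net-style network.

    Assumed layout: encoder stage k convolves to base_filters * 2**k channels,
    keeps that output for a skip connection, then halves height and width.
    A bridge uses base_filters * 2**depth channels. Decoder stage k doubles
    height and width, produces base_filters * 2**k channels, and (optionally)
    depth-concatenates the matching encoder output to its input.
    """
    if height % 2 ** depth or width % 2 ** depth:
        raise ValueError("height and width must be divisible by 2**depth")
    h, w = height, width
    trace = [("input", (h, w, in_channels))]
    skips = []
    for k in range(depth):
        c = base_filters * 2 ** k
        skips.append((h, w, c))  # encoder features saved for the skip connection
        trace.append((f"encoder{k + 1}", (h, w, c)))
        h, w = h // 2, w // 2  # downsampling
    trace.append(("bridge", (h, w, base_filters * 2 ** depth)))
    for k in reversed(range(depth)):
        h, w = h * 2, w * 2  # upsampling
        c = base_filters * 2 ** k
        c_in = c
        if skip_connections:
            # Depth concatenation requires matching height and width.
            assert skips[k][:2] == (h, w)
            c_in += skips[k][2]
        trace.append((f"decoder{k + 1}_in", (h, w, c_in)))
        trace.append((f"decoder{k + 1}", (h, w, c)))
    trace.append(("output", (h, w, out_channels)))
    return trace

for name, shape in encoder_decoder_shapes(64, 64, 3, 3):
    print(name, shape)
```

With these defaults, the bridge runs at 16-by-16 with 64 channels, and each decoder input carries twice the channels of its decoder output because of the concatenated skip features. Replacing concatenation with addition, as mentioned for Deep Network Designer above, would instead leave the channel count unchanged.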
Train GAN Network
To train GAN generator and discriminator networks, you must use a custom training loop. There are several steps involved in preparing a custom training loop. For an example that shows the complete workflow, see Train Generative Adversarial Network (GAN) (Deep Learning Toolbox).
1. Create the generator and discriminator networks.
2. Create one or more datastores that read, preprocess, and augment training data. For more information, see Datastores for Deep Learning (Deep Learning Toolbox). Then, for each datastore, create a minibatchqueue (Deep Learning Toolbox) object that manages the mini-batching of observations in the custom training loop.
3. Define the model gradients function for each network. The function takes as input the network and a mini-batch of input data, and returns the gradients of the loss. Optionally, you can pass extra arguments to the gradients function (for example, if the loss function requires extra information), or return extra arguments (for example, the loss values). For more information, see Define Model Loss Function for Custom Training Loop (Deep Learning Toolbox).
4. Define the loss functions. Certain types of loss functions are commonly used for image-to-image translation applications, although the implementation of each loss can vary.
   - Adversarial loss is commonly used by generator and discriminator networks. This loss relies on the pixelwise or patchwise difference between the correct classification and the classification that the discriminator predicts.
   - Cycle consistency loss is commonly used by unsupervised generator networks. This loss is based on the principle that an image translated from one domain to another, then back to the original domain, should be identical to the original image.
5. Specify training options such as the solver type and the number of epochs. For more information, see Specify Training Options in Custom Training Loop (Deep Learning Toolbox).
6. Create a custom training loop that iterates over the mini-batches in every epoch. The loop reads each mini-batch of data, evaluates the model gradients using the dlfeval (Deep Learning Toolbox) function, and updates the network parameters.
7. Optionally, include display functions, such as plots of scores or batches of generated images, that enable you to monitor the training progress. For more information, see Monitor GAN Training Progress and Identify Common Failure Modes (Deep Learning Toolbox).
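The adversarial and cycle consistency losses described in the training steps can be sketched in plain Python as follows. Images are flat lists of pixel values, and a discriminator output is a 2-D list of patch scores. The least-squares form of the adversarial loss and the L1 (mean absolute difference) form of the cycle consistency loss are assumptions chosen for illustration; as noted above, the implementation of each loss can vary. All function names are hypothetical.

```python
def patch_adversarial_loss(patch_scores, target):
    """Least-squares adversarial loss averaged over every patch score.

    patch_scores: 2-D list with one discriminator score per patch.
    target: 1.0 when the patches should be classified as real, 0.0 for fake.
    """
    diffs = [(score - target) ** 2 for row in patch_scores for score in row]
    return sum(diffs) / len(diffs)

def cycle_consistency_loss(images_a, gen_ab, gen_ba):
    """Mean absolute difference between images and their A -> B -> A round trips."""
    total, count = 0.0, 0
    for image in images_a:
        reconstructed = gen_ba(gen_ab(image))  # translate to B, then back to A
        for original, recon in zip(image, reconstructed):
            total += abs(original - recon)
            count += 1
    return total / count

def brighten(image):
    return [p + 0.5 for p in image]  # toy A-to-B "generator"

def darken(image):
    return [p - 0.5 for p in image]  # toy B-to-A "generator", the exact inverse

print(patch_adversarial_loss([[1.0, 1.0], [0.5, 0.5]], 1.0))  # 0.125
print(cycle_consistency_loss([[0.0, 1.0]], brighten, darken))  # 0.0
```

A generator pair whose round trip returns the input exactly has zero cycle consistency loss; any information lost or altered on the way through domain B shows up as a positive loss.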
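The training steps can be sketched end to end on a deliberately tiny problem. This plain-Python example is a toy one-dimensional GAN, not an image-to-image model and not MATLAB code: the generator is x = a*z + b, the discriminator is sigmoid(w*x + c), and the real data is synthetic Gaussian samples. The minibatches helper plays a role loosely similar to a minibatchqueue object, the gradients function plays the role of a model gradients function, and adam_step stands in for the solver. All of these names, the hand-derived gradients, and the hyperparameters are assumptions of this sketch.

```python
import math
import random

def sigmoid(u):
    """Logistic function written to avoid overflow for large |u|."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def mean(values):
    values = list(values)
    return sum(values) / len(values)

def minibatches(data, batch_size, rng):
    """Shuffle the data and yield full mini-batches; a trailing partial batch is dropped."""
    order = list(data)
    rng.shuffle(order)
    for start in range(0, len(order) - batch_size + 1, batch_size):
        yield order[start:start + batch_size]

def gradients(params, real, z):
    """Losses and hand-derived gradients for generator x = a*z + b and
    discriminator D(x) = sigmoid(w*x + c). The gradient formulas ignore the
    small eps guard inside the logarithms."""
    a, b, w, c = params["a"], params["b"], params["w"], params["c"]
    fake = [a * zi + b for zi in z]
    s_real = [sigmoid(w * x + c) for x in real]
    s_fake = [sigmoid(w * x + c) for x in fake]
    eps = 1e-12
    loss_d = (mean(-math.log(s + eps) for s in s_real)
              + mean(-math.log(1.0 - s + eps) for s in s_fake))
    loss_g = mean(-math.log(s + eps) for s in s_fake)  # non-saturating generator loss
    grads_d = {
        "w": mean(-(1 - s) * x for s, x in zip(s_real, real))
             + mean(s * x for s, x in zip(s_fake, fake)),
        "c": mean(-(1 - s) for s in s_real) + mean(s_fake),
    }
    grads_g = {
        "a": mean(-(1 - s) * w * zi for s, zi in zip(s_fake, z)),
        "b": mean(-(1 - s) * w for s in s_fake),
    }
    return loss_d, loss_g, grads_d, grads_g

def adam_step(params, grads, state, lr=0.01, beta1=0.5, beta2=0.999, eps=1e-8):
    """Apply one Adam update in place to the parameters named in grads."""
    state["t"] = state.get("t", 0) + 1
    t = state["t"]
    for name, g in grads.items():
        m = beta1 * state.get("m_" + name, 0.0) + (1 - beta1) * g
        v = beta2 * state.get("v_" + name, 0.0) + (1 - beta2) * g * g
        state["m_" + name], state["v_" + name] = m, v
        m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
        v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
        params[name] -= lr * m_hat / (math.sqrt(v_hat) + eps)

def train(num_epochs=50, batch_size=32, seed=0):
    """Custom training loop: epochs -> mini-batches -> gradients -> updates."""
    rng = random.Random(seed)
    real_data = [rng.gauss(4.0, 0.5) for _ in range(512)]  # synthetic "target domain"
    params = {"a": 1.0, "b": 0.0, "w": 0.1, "c": 0.0}
    state_d, state_g = {}, {}  # separate solver state for each network
    history = []
    for _ in range(num_epochs):
        for real in minibatches(real_data, batch_size, rng):
            z = [rng.gauss(0.0, 1.0) for _ in real]
            loss_d, loss_g, grads_d, grads_g = gradients(params, real, z)
            adam_step(params, grads_d, state_d)  # discriminator update
            adam_step(params, grads_g, state_g)  # generator update
            history.append((loss_d, loss_g))     # scores kept for monitoring
    return params, history

params, history = train(num_epochs=20)
print(len(history))  # 320 iterations: 20 epochs x 16 full mini-batches
print(params)
```

Both networks are updated from the same forward pass here, which keeps the sketch short; evaluating the gradients again after the discriminator update, so that the two networks strictly alternate, is a common variation. The recorded history is the kind of data that the optional monitoring plots would display.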
- Unsupervised Day-to-Dusk Image Translation Using UNIT
- Generate Image from Segmentation Map Using Deep Learning (Computer Vision Toolbox)
- Create Modular Neural Networks
- Train Generative Adversarial Network (GAN) (Deep Learning Toolbox)
- Define Custom Training Loops, Loss Functions, and Networks (Deep Learning Toolbox)
- Define Model Loss Function for Custom Training Loop (Deep Learning Toolbox)
- Specify Training Options in Custom Training Loop (Deep Learning Toolbox)
- Train Network Using Custom Training Loop (Deep Learning Toolbox)