
Train Deep Learning Model in MATLAB

You can train and customize a deep learning model in various ways—for example, you can retrain a pretrained model with new data (transfer learning), train a network from scratch, or define a deep learning model as a function and use a custom training loop. Use this flow chart to choose the training method that is best suited for your task.

[Figure: Flowchart showing the decision process for choosing a training method.]

Tip

For information on computer vision workflows, including for object detection, see Computer Vision. For information on importing networks and network architectures from TensorFlow™-Keras, Caffe, and the ONNX™ (Open Neural Network Exchange) model format, see Pretrained Networks from External Platforms.

Training Methods

This table provides information about the different training methods.

Method | More Information
Use network directly

If a pretrained network already performs the task you require, then you do not need to retrain the network. Instead, you can make predictions with the network directly by using the minibatchpredict or predict function. To convert classification scores to labels, use the scores2label function.

For an example, see Classify Image Using GoogLeNet.
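As a minimal sketch of using a network directly, the following code loads the pretrained SqueezeNet network and classifies a single image. It assumes that the peppers.png sample image that ships with MATLAB is available and that resizing is the only preprocessing required; your own data might need different preparation.

```matlab
% Load the pretrained SqueezeNet network and its class names.
[net,classNames] = imagePretrainedNetwork("squeezenet");

% Read a sample image and resize it to the network input size.
% peppers.png is used here only as an illustrative input.
inputSize = net.Layers(1).InputSize;
I = imread("peppers.png");
X = single(imresize(I,inputSize(1:2)));

% Compute classification scores and convert them to a label.
scores = minibatchpredict(net,X);
label = scores2label(scores,classNames);
disp(label)
```

The minibatchpredict function processes data in mini-batches, which is convenient when you make predictions on many observations at once.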

Train network using trainnet

If you have a network specified as a layer array or dlnetwork object, and the trainingOptions function provides all the options you need, then you can train the network using the trainnet function.

For an example showing how to retrain a network (transfer learning), see Retrain Neural Network to Classify New Images. For an example showing how to train a network from scratch, see Create Simple Deep Learning Neural Network for Classification.
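The following sketch illustrates the trainnet workflow for a network trained from scratch. The synthetic feature data, the class names "low" and "high", the layer sizes, and the option values are all assumptions chosen for illustration, not recommended settings.

```matlab
% Synthetic data (hypothetical): 200 observations with 4 features each.
rng(0)
X = rand(200,4);
isHigh = X(:,1) + X(:,2) > 1;
T = categorical(isHigh,[false true],["low" "high"]);

% Specify the network as a layer array.
layers = [
    featureInputLayer(4)
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(2)
    softmaxLayer];

% Specify training options. These values are illustrative only.
options = trainingOptions("adam", ...
    MaxEpochs=20, ...
    MiniBatchSize=32, ...
    Verbose=false);

% Train using the built-in cross-entropy loss.
net = trainnet(X,T,layers,"crossentropy",options);
```

The trained network is returned as a dlnetwork object, so you can pass it to minibatchpredict or predict in the same way as a pretrained network.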

Train network using custom training loop

For most tasks, you can control the training algorithm details using the trainingOptions and trainnet functions. If the trainingOptions function does not provide the options you need for your task (for example, a custom solver), then you can define your own custom training loop.

For loss functions that cannot be specified using a function handle, or for models that cannot be specified as networks of layers, you can train the model using a custom training loop.

For an example showing how to train a neural network using a custom training loop, see Train Network Using Custom Training Loop.

For models that cannot be specified as a network of layers, you can define the model as a function. For an example showing how to train a deep learning model defined as a function, see Train Network Using Model Function.

To learn more, see Define Custom Training Loops, Loss Functions, and Networks.
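As a rough sketch of what a custom training loop can look like, the following code trains a small dlnetwork object using full-batch Adam updates. The synthetic data, the modelLoss helper function, the learning rate, and the iteration count are assumptions made for illustration. A practical loop typically also handles mini-batching, validation, and progress monitoring.

```matlab
% Synthetic data (hypothetical): features-by-observations, one-hot targets.
rng(0)
X = rand(4,200,"single");
isHigh = X(1,:) + X(2,:) > 1;
T = single([~isHigh; isHigh]);

layers = [
    featureInputLayer(4)
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(2)
    softmaxLayer];
net = dlnetwork(layers);

% Format the input as channel-by-batch data.
X = dlarray(X,"CB");

% Custom loop using the Adam update rule.
averageGrad = [];
averageSqGrad = [];
learnRate = 0.01;
numIterations = 100;
lossHistory = zeros(1,numIterations);

for iteration = 1:numIterations
    % Evaluate the loss and gradients with automatic differentiation.
    [loss,gradients] = dlfeval(@modelLoss,net,X,T);

    % Update the learnable parameters.
    [net,averageGrad,averageSqGrad] = adamupdate(net,gradients, ...
        averageGrad,averageSqGrad,iteration,learnRate);

    lossHistory(iteration) = extractdata(loss);
end

% Hypothetical helper: forward pass, loss, and gradients.
function [loss,gradients] = modelLoss(net,X,T)
    Y = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end
```

Because the update step is ordinary code, you can replace adamupdate with another update function or with your own solver logic.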

Decisions

This table provides more information on each decision in the flow chart.

Decision | More Information
Does the software provide a suitable pretrained network?

For most tasks, you can use or retrain a pretrained network such as SqueezeNet.

For a list of pretrained deep learning networks in MATLAB®, see Pretrained Deep Neural Networks. You can use pretrained networks directly with new data, or you can retrain them with new data for different tasks using transfer learning.

Can you use the network without retraining?

If a pretrained network already performs the task you need, then you can use the network directly without retraining. For example, you can use a pretrained SqueezeNet neural network to classify images into 1000 classes. To load a pretrained SqueezeNet neural network, use the imagePretrainedNetwork function. To make predictions with the network directly, use the minibatchpredict or predict function, and then convert the scores to labels using the scores2label function. For an example, see Classify Image Using GoogLeNet.

If you need to retrain the network—for example, to classify a different set of classes—then you can retrain the network using transfer learning. For an example, see Retrain Neural Network to Classify New Images.
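A transfer learning setup might look like the following sketch. It assumes that the NumClasses option of imagePretrainedNetwork is available in your release, and it uses random images with made-up labels purely as placeholders for a real labeled image data set.

```matlab
% Load SqueezeNet adapted to a new number of classes (assumed option).
numClasses = 2;
net = imagePretrainedNetwork("squeezenet",NumClasses=numClasses);
inputSize = net.Layers(1).InputSize;

% Placeholder data: 8 random images with hypothetical labels.
% Replace these with your own labeled images.
rng(0)
X = 255*rand([inputSize 8],"single");
T = categorical(["cat";"dog";"cat";"dog";"cat";"dog";"cat";"dog"]);

% A small learning rate is a common choice when fine-tuning.
options = trainingOptions("sgdm", ...
    InitialLearnRate=1e-4, ...
    MaxEpochs=1, ...
    MiniBatchSize=4, ...
    Verbose=false);

% Retrain the adapted network on the new data.
net = trainnet(X,T,net,"crossentropy",options);
```

With real data, you would usually pass an image datastore instead of a numeric array and train for more than one epoch.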

Can you define the model as a layer array or a network of layers?

You can specify most deep learning models as a layer array or a network of layers. In other words, you can define the model as a collection of layers with layer outputs connected to other layer inputs.

Some models cannot be defined as a network of layers. For example, twin networks require weight sharing and cannot be defined as a network of layers. For these models, you must define the model as a function. For an example, see Train Network Using Model Function.
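To illustrate weight sharing with a model function, the following sketch applies one set of parameters to two inputs and computes gradients of a simple loss. The parameter names, the single fully connected operation, and the similarity loss are hypothetical choices for illustration and are not the method used in the linked example.

```matlab
% Hypothetical learnable parameters stored in a structure.
rng(0)
parameters.fc1.Weights = dlarray(0.1*randn(8,4,"single"));
parameters.fc1.Bias = dlarray(zeros(8,1,"single"));

% Two batches of paired inputs and made-up similarity targets.
X1 = dlarray(rand(4,16,"single"),"CB");
X2 = dlarray(rand(4,16,"single"),"CB");
T = single(rand(1,16) > 0.5);

% Evaluate the loss and the gradients with respect to the parameters.
[loss,gradients] = dlfeval(@modelLoss,parameters,X1,X2,T);

function [loss,gradients] = modelLoss(parameters,X1,X2,T)
    % Both branches use the same parameters (weight sharing).
    F1 = stripdims(model(parameters,X1));
    F2 = stripdims(model(parameters,X2));

    % Map the squared distance between features to a value in (0,1].
    distance = sum((F1 - F2).^2,1);
    Y = exp(-distance);

    % Illustrative loss: mean squared error against the targets.
    loss = mean((Y - T).^2,"all");
    gradients = dlgradient(loss,parameters);
end

function Y = model(parameters,X)
    % Model defined as a function rather than as a network of layers.
    Y = fullyconnect(X,parameters.fc1.Weights,parameters.fc1.Bias);
    Y = relu(Y);
end
```

The gradients structure has the same fields as parameters, so you can pass both to an update function such as adamupdate inside a custom training loop.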

Does the software provide the layers you need?

Deep Learning Toolbox™ provides many different layers for deep learning tasks. For a list of layers, see List of Deep Learning Layers.

If the software provides the layers that you need, then you can define the model as a layer array or a neural network of these layers. Otherwise, try defining any unsupported layers as custom layers. For more information, see Define Custom Deep Learning Layers.

Can you define the unsupported layers as custom layers?

If the software does not provide the layer you need, then you can try defining a custom deep learning layer. For more information, see Define Custom Deep Learning Layers.

If you can define custom layers for any unsupported layers, then you can include these custom layers in a layer array or neural network. Otherwise, specify the deep learning model using a function and train the model using a custom training loop. For an example, see Train Network Using Model Function.
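As a rough illustration, a custom layer with one learnable parameter could be sketched as follows. The class name scaledTanhLayer and its Scale property are hypothetical, and the code assumes that you save the class in a file named scaledTanhLayer.m on the MATLAB path.

```matlab
classdef scaledTanhLayer < nnet.layer.Layer
    % Hypothetical custom layer that computes Y = Scale .* tanh(X).

    properties (Learnable)
        % Learnable scalar applied to the output.
        Scale
    end

    methods
        function layer = scaledTanhLayer(name)
            % Create the layer and initialize the learnable parameter.
            layer.Name = name;
            layer.Description = "Scaled tanh";
            layer.Scale = 1;
        end

        function Y = predict(layer,X)
            % Forward pass used for both prediction and training.
            Y = layer.Scale .* tanh(X);
        end
    end
end
```

After you define the class, you can create the layer, for example with scaledTanhLayer("act"), and include it in a layer array alongside built-in layers.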

Does the trainnet function provide the loss function you need?

The trainnet function provides several built-in loss functions for deep learning tasks, for example, "crossentropy" and "mse". For more information, see the loss function argument of the trainnet function.

If the trainnet function does not provide the loss function that you need, then try defining the loss function as a function handle or train the model using a custom training loop.

Can you define the loss function as a function handle?

If the trainnet function does not provide the loss function you need, then you can try defining a custom loss function as a function handle. For more information, see the loss function argument of the trainnet function.

If you can define the loss function as a function handle, then you can pass it as the loss function argument of the trainnet function. Otherwise, train the model using a dlnetwork object and a custom training loop. For an example, see Train Network Using Custom Training Loop.
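The following sketch shows one way to pass a custom loss to trainnet as a function handle. The mean absolute error loss, the synthetic regression data, and the option values are assumptions chosen for illustration.

```matlab
% Synthetic regression data (hypothetical): 200 observations, 4 features.
rng(0)
X = rand(200,4);
T = X*[1; -2; 0.5; 3] + 0.05*randn(200,1);

layers = [
    featureInputLayer(4)
    fullyConnectedLayer(8)
    reluLayer
    fullyConnectedLayer(1)];

% Custom loss as a function handle: mean absolute error between the
% predictions Y and the targets T.
lossFcn = @(Y,T) mean(abs(Y - T),"all");

options = trainingOptions("adam", ...
    MaxEpochs=30, ...
    Verbose=false);

net = trainnet(X,T,layers,lossFcn,options);
```

The function handle must return a scalar loss that is computed with operations that support automatic differentiation.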

Does the trainingOptions function provide the options you need?

The trainingOptions function provides many options for customizing the training process. If the trainingOptions function provides all the options you need for training, then you can train the deep learning network using the trainnet function. For an example, see Create Simple Deep Learning Neural Network for Classification.

If the trainingOptions function does not provide the training option you need, for example, a custom solver, then you can define a custom training loop using a dlnetwork object. For an example, see Train Network Using Custom Training Loop.
