Try Multiple Pretrained Networks for Transfer Learning
This example shows how to configure an experiment that replaces layers of different pretrained networks for transfer learning. Transfer learning is commonly used in deep learning applications. You can take a pretrained network and use it as a starting point to learn a new task. Fine-tuning a network with transfer learning is usually much faster and easier than training a network with randomly initialized weights from scratch. You can quickly transfer learned features to a new task using a smaller number of training images.
There are many pretrained networks available in Deep Learning Toolbox™. These pretrained networks have different characteristics that matter when choosing a network to apply to your problem. The most important characteristics are network accuracy, speed, and size. Choosing a network is generally a tradeoff between these characteristics. To compare the performance of different pretrained networks for your task, edit this experiment and specify which pretrained networks to use.
This experiment requires the Deep Learning Toolbox Model for GoogLeNet Network support package and the Deep Learning Toolbox Model for ResNet-18 Network support package. Before you run the experiment, install these support packages by calling the googlenet and resnet18 functions and clicking the download links. For more information on other pretrained networks that you can download from the Add-On Explorer, see Pretrained Deep Neural Networks.
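For instance, this minimal check at the MATLAB command line loads each network, assuming its support package is installed; if a package is missing, the call produces an error message that contains the download link:

% Load each pretrained network. If a support package is not installed,
% the function provides a link to download it from the Add-On Explorer.
net = googlenet;
net = resnet18;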
Open Experiment
First, open the example. Experiment Manager loads a project with a preconfigured experiment that you can inspect and run. To open the experiment, in the Experiment Browser pane, double-click TransferLearningExperiment.
Built-in training experiments consist of a description, a table of hyperparameters, a setup function, and a collection of metric functions to evaluate the results of the experiment. For more information, see Train Network Using trainnet and Display Custom Metrics.
The Description field contains a textual description of the experiment. For this example, the description is:
Perform transfer learning by replacing layers in a pretrained network.
The Hyperparameters section specifies the strategy and hyperparameter values to use for the experiment. When you run the experiment, Experiment Manager trains the network using every combination of hyperparameter values specified in the hyperparameter table. In this example, the hyperparameter NetworkName specifies the network to train and determines the value of the training option miniBatchSize.
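For instance, to produce the six trials described in Run Experiment, the Values column of the hyperparameter table might contain the following list; the exact set of networks is an assumption and depends on how you edit the experiment:

NetworkName = ["squeezenet" "googlenet" "resnet18" "mobilenetv2" "resnet50" "resnet101"]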
The Setup Function section specifies a function that configures the training data, network architecture, and training options for the experiment. To open this function in MATLAB® Editor, click Edit. The code for the function also appears in Setup Function. The input to the setup function is a structure with fields from the hyperparameter table. The function returns three outputs that you use to train a network for image classification problems. In this example, the setup function:
Loads a pretrained network corresponding to the hyperparameter NetworkName.
networkName = params.NetworkName;
switch networkName
    case "squeezenet"
        net = squeezenet;
        miniBatchSize = 128;
    case "googlenet"
        net = googlenet;
        miniBatchSize = 128;
    case "resnet18"
        net = resnet18;
        miniBatchSize = 128;
    case "mobilenetv2"
        net = mobilenetv2;
        miniBatchSize = 128;
    case "resnet50"
        net = resnet50;
        miniBatchSize = 128;
    case "resnet101"
        net = resnet101;
        miniBatchSize = 64;
    case "inceptionv3"
        net = inceptionv3;
        miniBatchSize = 64;
    case "inceptionresnetv2"
        net = inceptionresnetv2;
        miniBatchSize = 64;
    otherwise
        error("Undefined network selection.")
end
Downloads and extracts the Flowers data set, which is about 218 MB. For more information on this data set, see Image Data Sets.
url = "http://download.tensorflow.org/example_images/flower_photos.tgz"; downloadFolder = tempdir; filename = fullfile(downloadFolder,"flower_dataset.tgz"); imageFolder = fullfile(downloadFolder,"flower_photos"); if ~exist(imageFolder,"dir") disp("Downloading Flower Dataset (218 MB)...") websave(filename,url); untar(filename,downloadFolder) end imds = imageDatastore(imageFolder, ... IncludeSubfolders=true, ... LabelSource="foldernames"); [imdsTrain,imdsValidation] = splitEachLabel(imds,0.9); inputSize = net.Layers(1).InputSize; augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain); augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);
Replaces the learnable layers of the pretrained network to perform transfer learning. The helper function findLayersToReplace determines the layers in the network architecture that can be modified for transfer learning. To view the code for this function, see Find Layers to Replace. For more information on the available pretrained networks, see Pretrained Deep Neural Networks.
lgraph = layerGraph(net);
[learnableLayer,classLayer] = findLayersToReplace(lgraph);

numClasses = numel(categories(imdsTrain.Labels));

if isa(learnableLayer,"nnet.cnn.layer.FullyConnectedLayer")
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        Name="new_fc", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
elseif isa(learnableLayer,"nnet.cnn.layer.Convolution2DLayer")
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        Name="new_conv", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
end

lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

newClassLayer = classificationLayer(Name="new_classoutput");
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);
Defines a trainingOptions object for the experiment. The example trains the network for 10 epochs, using an initial learning rate of 0.0003 and validating the network every 5 epochs.
validationFrequencyEpochs = 5;

numObservations = augimdsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);
validationFrequency = validationFrequencyEpochs * numIterationsPerEpoch;

options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    MiniBatchSize=miniBatchSize, ...
    InitialLearnRate=3e-4, ...
    Shuffle="every-epoch", ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=validationFrequency, ...
    Verbose=false);
The Metrics section specifies optional functions that evaluate the results of the experiment. This example does not include any custom metric functions.
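If you do add one, a metric function takes a single input argument, a structure with three fields (trainedNetwork, trainingInfo, and parameters), and returns a scalar. A minimal sketch, where the function name and the choice of metric are assumptions:

function metric = finalValidationLoss(trialInfo)
% Hypothetical custom metric: return the last recorded validation loss.
% trialInfo.trainingInfo.ValidationLoss is NaN at iterations where no
% validation occurs, so keep only the last non-NaN value.
valLoss = trialInfo.trainingInfo.ValidationLoss;
metric = valLoss(find(~isnan(valLoss),1,"last"));
end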
Run Experiment
When you run the experiment, Experiment Manager trains the network defined by the setup function six times. Each trial uses a different combination of hyperparameter values. By default, Experiment Manager runs one trial at a time. If you have Parallel Computing Toolbox™, you can run multiple trials at the same time or offload your experiment as a batch job in a cluster:
To run one trial of the experiment at a time, on the Experiment Manager toolstrip, set Mode to Sequential and click Run.

To run multiple trials at the same time, set Mode to Simultaneous and click Run. If there is no current parallel pool, Experiment Manager starts one using the default cluster profile. Experiment Manager then runs as many simultaneous trials as there are workers in your parallel pool. For best results, before you run your experiment, start a parallel pool with as many workers as GPUs. For more information, see Run Experiments in Parallel and GPU Computing Requirements (Parallel Computing Toolbox).

To offload the experiment as a batch job, set Mode to Batch Sequential or Batch Simultaneous, specify your cluster and pool size, and click Run. For more information, see Offload Experiments as Batch Jobs to a Cluster.
A table of results displays the accuracy and loss for each trial.
When the experiment finishes, you can sort the results table by column, filter trials by using the Filters pane, or record observations by adding annotations.
To test the performance of an individual trial, export the trained network or the training information for the trial. On the Experiment Manager toolstrip, select Export > Trained Network or Export > Training Information, respectively. For more information, see net and info. To save the contents of the results table as a table array in the MATLAB workspace, select Export > Results Table.
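For example, after you export a trained network to the workspace, you can classify a single validation image. This sketch assumes the exported variable is named trainedNetwork (Experiment Manager prompts you for the name) and that the imdsValidation datastore from the setup function is available:

% Classify one validation image with the exported network.
img = readimage(imdsValidation,1);
inputSize = trainedNetwork.Layers(1).InputSize;
img = imresize(img,inputSize(1:2));
label = classify(trainedNetwork,img)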
Close Experiment
In the Experiment Browser pane, right-click FlowerTransferLearningProject and select Close Project. Experiment Manager closes the experiments and results contained in the project.
Setup Function
This function configures the training data, network architecture, and training options for the experiment. The input to this function is a structure with fields from the hyperparameter table. The function returns three outputs that you use to train a network for image classification problems.
function [augimdsTrain,lgraph,options] = TransferLearningExperiment_setup(params)
Load Pretrained Network
networkName = params.NetworkName;
switch networkName
    case "squeezenet"
        net = squeezenet;
        miniBatchSize = 128;
    case "googlenet"
        net = googlenet;
        miniBatchSize = 128;
    case "resnet18"
        net = resnet18;
        miniBatchSize = 128;
    case "mobilenetv2"
        net = mobilenetv2;
        miniBatchSize = 128;
    case "resnet50"
        net = resnet50;
        miniBatchSize = 128;
    case "resnet101"
        net = resnet101;
        miniBatchSize = 64;
    case "inceptionv3"
        net = inceptionv3;
        miniBatchSize = 64;
    case "inceptionresnetv2"
        net = inceptionresnetv2;
        miniBatchSize = 64;
    otherwise
        error("Undefined network selection.")
end
Load Training Data
url = "http://download.tensorflow.org/example_images/flower_photos.tgz"; downloadFolder = tempdir; filename = fullfile(downloadFolder,"flower_dataset.tgz"); imageFolder = fullfile(downloadFolder,"flower_photos"); if ~exist(imageFolder,"dir") disp("Downloading Flower Dataset (218 MB)...") websave(filename,url); untar(filename,downloadFolder) end imds = imageDatastore(imageFolder, ... IncludeSubfolders=true, ... LabelSource="foldernames"); [imdsTrain,imdsValidation] = splitEachLabel(imds,0.9); inputSize = net.Layers(1).InputSize; augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain); augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);
Define Network Architecture
lgraph = layerGraph(net);
[learnableLayer,classLayer] = findLayersToReplace(lgraph);

numClasses = numel(categories(imdsTrain.Labels));

if isa(learnableLayer,"nnet.cnn.layer.FullyConnectedLayer")
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        Name="new_fc", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
elseif isa(learnableLayer,"nnet.cnn.layer.Convolution2DLayer")
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        Name="new_conv", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
end

lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

newClassLayer = classificationLayer(Name="new_classoutput");
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);
Specify Training Options
validationFrequencyEpochs = 5;

numObservations = augimdsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);
validationFrequency = validationFrequencyEpochs * numIterationsPerEpoch;

options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    MiniBatchSize=miniBatchSize, ...
    InitialLearnRate=3e-4, ...
    Shuffle="every-epoch", ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=validationFrequency, ...
    Verbose=false);
end
Find Layers to Replace
This function finds the single classification layer and the preceding learnable (fully connected or convolutional) layer of the layer graph lgraph.
function [learnableLayer,classLayer] = findLayersToReplace(lgraph)

if ~isa(lgraph,"nnet.cnn.LayerGraph")
    error("Argument must be a LayerGraph object.")
end

src = string(lgraph.Connections.Source);
dst = string(lgraph.Connections.Destination);
layerNames = string({lgraph.Layers.Name}');

isClassificationLayer = arrayfun(@(l) ...
    (isa(l,"nnet.cnn.layer.ClassificationOutputLayer") | isa(l,"nnet.layer.ClassificationLayer")), ...
    lgraph.Layers);

if sum(isClassificationLayer) ~= 1
    error("Layer graph must have a single classification layer.")
end
classLayer = lgraph.Layers(isClassificationLayer);

currentLayerIdx = find(isClassificationLayer);
while true
    if numel(currentLayerIdx) ~= 1
        error("Layer graph must have a single learnable layer preceding the classification layer.")
    end

    currentLayerType = class(lgraph.Layers(currentLayerIdx));
    isLearnableLayer = ismember(currentLayerType, ...
        ["nnet.cnn.layer.FullyConnectedLayer","nnet.cnn.layer.Convolution2DLayer"]);

    if isLearnableLayer
        learnableLayer = lgraph.Layers(currentLayerIdx);
        return
    end

    currentDstIdx = find(layerNames(currentLayerIdx) == dst);
    currentLayerIdx = find(src(currentDstIdx) == layerNames);
end

end
See Also
Apps
Experiment Manager

Functions
googlenet | resnet18 | squeezenet | trainNetwork | trainingOptions | table