updatePrunables
Syntax
prunableNet_new = updatePrunables(prunableNet)
prunableNet_new = updatePrunables(prunableNet,MaxToPrune=maxToPrune)
Description
prunableNet_new = updatePrunables(prunableNet) removes up to 8 prunable filters from the convolution layers of prunableNet and returns an updated TaylorPrunableNetwork object. This function removes the filters that have the lowest Taylor-based importance scores.
To prune a deep neural network, you require the Deep Learning Toolbox™ Model Compression Library support package. This support package is a free add-on that you can download using the Add-On Explorer. Alternatively, see Deep Learning Toolbox Model Compression Library.
prunableNet_new = updatePrunables(prunableNet,MaxToPrune=maxToPrune) removes up to maxToPrune prunable filters from the convolution layers of prunableNet and returns an updated TaylorPrunableNetwork object.
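For example, to remove fewer filters per call than the default, specify MaxToPrune. The following is a minimal sketch that assumes prunableNet is a TaylorPrunableNetwork object whose importance scores you have already updated by using updateScore. The value 4 and the variable name numRemoved are illustrative choices, not part of the original page.

```matlab
% Assumes prunableNet is a TaylorPrunableNetwork object with updated scores.
% Remove at most 4 filters instead of the default 8.
prunableNet_new = updatePrunables(prunableNet,MaxToPrune=4);

% Number of filters removed by this call, based on the change in NumPrunables.
numRemoved = prunableNet.NumPrunables - prunableNet_new.NumPrunables
```

Smaller steps give finer control over how much of the network is removed between scoring passes, at the cost of more passes over the data.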
Examples
This example shows how to prune a dlnetwork object by using a custom pruning loop.
Load dlnetwork Object
Load a trained dlnetwork object and the corresponding classes.
s = load("digitsCustom.mat");
dlnet_1 = s.dlnet;
classes = s.classes;
Inspect the layers of the dlnetwork object. The network has three convolution layers at locations 2, 5, and 8 of the Layer array.
layers_1 = dlnet_1.Layers
layers_1 =
12×1 Layer array with layers:
1 'input' Image Input 28×28×1 images with 'zerocenter' normalization
2 'conv1' 2-D Convolution 20 5×5×1 convolutions with stride [1 1] and padding [0 0 0 0]
3 'bn1' Batch Normalization Batch normalization with 20 channels
4 'relu1' ReLU ReLU
5 'conv2' 2-D Convolution 20 3×3×20 convolutions with stride [1 1] and padding [1 1 1 1]
6 'bn2' Batch Normalization Batch normalization with 20 channels
7 'relu2' ReLU ReLU
8 'conv3' 2-D Convolution 20 3×3×20 convolutions with stride [1 1] and padding [1 1 1 1]
9 'bn3' Batch Normalization Batch normalization with 20 channels
10 'relu3' ReLU ReLU
11 'fc' Fully Connected 10 fully connected layer
12 'softmax' Softmax softmax
Load Data for Prediction
Load the digits data for prediction.
dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
imds = imageDatastore(dataFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
Partition the data into pruning and validation sets. Set aside 10% of the data for validation using the splitEachLabel function.
[imdsPrune,imdsValidation] = splitEachLabel(imds,0.9,"randomize");
The network used in this example requires input images of size 28-by-28-by-1. To automatically resize the images, use augmented image datastores.
inputSize = [28 28 1];
augimdsPrune = augmentedImageDatastore(inputSize(1:2),imdsPrune);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
Prune dlnetwork Object
Convert the dlnetwork object to a representation that is suitable for pruning by using the taylorPrunableNetwork function. This function returns a TaylorPrunableNetwork object that has the NumPrunables property set to 48. This indicates that 48 filters in the original model are suitable for pruning by using the Taylor pruning algorithm.
prunableNet_1 = taylorPrunableNetwork(dlnet_1)
prunableNet_1 =
TaylorPrunableNetwork with properties:
Learnables: [14×3 table]
State: [6×3 table]
InputNames: {'input'}
OutputNames: {'softmax'}
NumPrunables: 48
Create a minibatchqueue object that processes and manages mini-batches of images during pruning. For each mini-batch:
- Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to convert the labels to one-hot encoded variables.
- Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not format the class labels.
- Perform computations on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
miniBatchSize = 128;
imds.ReadSize = miniBatchSize;
mbq = minibatchqueue(augimdsPrune, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" ""]);
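Before scoring, it can help to read one mini-batch and confirm that the queue produces the expected shapes and format. This is a minimal sketch that assumes mbq from this example is in the workspace; it calls reset so that the scoring loop still starts from the first mini-batch.

```matlab
% Read one mini-batch from the queue to check its size and format.
[X,T] = next(mbq);
disp(size(X))   % expected: 28 28 1 128 for a full mini-batch
disp(dims(X))   % expected: SSCB
disp(size(T))   % expected: number of classes by 128

% Rewind the queue so that the scoring loop starts from the first mini-batch.
reset(mbq);
```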
Calculate Taylor-based importance scores of the prunable filters in the network by looping over the mini-batches of data. For each mini-batch:
- Calculate the pruning activations and pruning gradients by using the modelLoss function defined at the end of this example.
- Update the importance scores of the prunable filters by using the updateScore function.
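while hasdata(mbq)
    [X,T] = next(mbq);
    % Compute the activations and gradients used for Taylor scoring.
    [~,pruningActivations,pruningGradients] = dlfeval(@modelLoss,prunableNet_1,X,T);
    % Update the importance scores of the prunable filters.
    prunableNet_1 = updateScore(prunableNet_1,pruningActivations,pruningGradients);
end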
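This page does not state the exact formula that updateScore applies. As an assumption for intuition only, a common first-order Taylor criterion estimates how much the loss L would change if the activation of filter k, with elements a_1 through a_M, were set to zero, and averages that estimate over the M elements of the activation:

```latex
L\bigl(a^{(k)} = 0\bigr) \approx L\bigl(a^{(k)}\bigr) - \sum_{m=1}^{M} \frac{\partial L}{\partial a_m^{(k)}} \, a_m^{(k)}

\Theta^{(k)} = \left| \frac{1}{M} \sum_{m=1}^{M} \frac{\partial L}{\partial a_m^{(k)}} \, a_m^{(k)} \right|
```

In this sketch, the activations play the role of pruningActivations and the partial derivatives play the role of pruningGradients. Filters with the smallest score are the ones whose removal is expected to change the loss the least.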
Finally, remove the filters with the lowest importance scores to create a new TaylorPrunableNetwork object by using the updatePrunables function. By default, a single call to this function removes up to 8 filters. Observe that the new network prunableNet_2 has 40 prunable filters remaining.
prunableNet_2 = updatePrunables(prunableNet_1)
prunableNet_2 =
TaylorPrunableNetwork with properties:
Learnables: [14×3 table]
State: [6×3 table]
InputNames: {'input'}
OutputNames: {'softmax'}
NumPrunables: 40
To compress the model further, repeat the custom pruning loop and call updatePrunables again.
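The repeated pruning step can be sketched as follows. This is a minimal sketch that assumes mbq, modelLoss, and prunableNet_2 from this example are available; the target of 24 prunable filters and the names targetNumPrunables and numBefore are hypothetical. The sketch does not retrain the network between pruning steps. Fine-tuning after pruning is a common way to recover accuracy, but it is not shown here.

```matlab
% Hypothetical stopping criterion: prune until at most 24 prunable filters remain.
targetNumPrunables = 24;
prunableNet = prunableNet_2;

while prunableNet.NumPrunables > targetNumPrunables
    % Reset the queue and shuffle the data for a new pass over the pruning set.
    shuffle(mbq);

    % Recompute importance scores for the filters of the current network.
    while hasdata(mbq)
        [X,T] = next(mbq);
        [~,pruningActivations,pruningGradients] = dlfeval(@modelLoss,prunableNet,X,T);
        prunableNet = updateScore(prunableNet,pruningActivations,pruningGradients);
    end

    % Remove up to 8 filters with the lowest importance scores.
    numBefore = prunableNet.NumPrunables;
    prunableNet = updatePrunables(prunableNet);

    % Stop if no filter was removed, to avoid an endless loop.
    if prunableNet.NumPrunables == numBefore
        break
    end
end
```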
Extract Pruned dlnetwork Object
Use the dlnetwork function to extract the pruned dlnetwork object from the pruned TaylorPrunableNetwork object. You can now use this compressed dlnetwork object to perform inference.
dlnet_2 = dlnetwork(prunableNet_2);
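The validation datastore augimdsValidation is created earlier but not used elsewhere on this page. One way to measure the classification accuracy of the pruned network on it is sketched below. This sketch assumes the variables and supporting functions from this example; the names mbqValidation, numCorrect, numObservations, and accuracy are hypothetical, and the predict and onehotdecode calls are additions that this page does not otherwise use.

```matlab
% Build a queue over the validation data that uses the same preprocessing.
mbqValidation = minibatchqueue(augimdsValidation, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" ""]);

numCorrect = 0;
numObservations = 0;
while hasdata(mbqValidation)
    [X,T] = next(mbqValidation);
    % Predict class scores with the pruned network.
    Y = predict(dlnet_2,X);
    % Convert the scores and the one-hot targets back to class labels.
    YPred = onehotdecode(Y,classes,1);
    YTrue = onehotdecode(T,classes,1);
    numCorrect = numCorrect + sum(YPred == YTrue);
    numObservations = numObservations + numel(YTrue);
end
accuracy = numCorrect/numObservations
```

To obtain a baseline for comparison, run the same loop with dlnet_1 in place of dlnet_2.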
Compare the convolution layers of the original and the pruned dlnetwork objects. Observe that the three convolution layers in the pruned network have fewer filters. These counts agree with the fact that, by default, a single call to the updatePrunables function removes 8 filters from the network.
conv_layers_1 = dlnet_1.Layers([2 5 8])
conv_layers_1 =
3×1 Convolution2DLayer array with layers:
1 'conv1' 2-D Convolution 20 5×5×1 convolutions with stride [1 1] and padding [0 0 0 0]
2 'conv2' 2-D Convolution 20 3×3×20 convolutions with stride [1 1] and padding [1 1 1 1]
3 'conv3' 2-D Convolution 20 3×3×20 convolutions with stride [1 1] and padding [1 1 1 1]
conv_layers_2 = dlnet_2.Layers([2 5 8])
conv_layers_2 =
3×1 Convolution2DLayer array with layers:
1 'conv1' 2-D Convolution 17 5×5×1 convolutions with stride [1 1] and padding [0 0 0 0]
2 'conv2' 2-D Convolution 18 3×3×17 convolutions with stride [1 1] and padding [1 1 1 1]
3 'conv3' 2-D Convolution 17 3×3×18 convolutions with stride [1 1] and padding [1 1 1 1]
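The filter counts can be checked directly. Each convolution layer starts with 20 filters, and the pruned layers keep 17, 18, and 17 filters, so the number of filters removed is:

```latex
(20 - 17) + (20 - 18) + (20 - 17) = 3 + 2 + 3 = 8
```

This matches the drop in NumPrunables from 48 to 40. The depth of the later filters also shrinks (3×3×20 becomes 3×3×17 in conv2 and 3×3×18 in conv3) because each of these layers receives the reduced number of output channels of the preceding convolution layer.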
Supporting Functions
Model Loss Function
The modelLoss function takes a TaylorPrunableNetwork object net and a mini-batch of input data X with corresponding targets T. It returns the cross-entropy loss, the pruning activations of net, and the gradients of the loss with respect to those activations. To compute the gradients automatically, this function uses the dlgradient function.
function [loss, pruningActivations, pruningGradients] = modelLoss(net,X,T)
    % Calculate network output for training.
    [out, ~, pruningActivations] = forward(net,X);
    % Calculate loss.
    loss = crossentropy(out,T);
    % Compute pruning gradients.
    pruningGradients = dlgradient(loss,pruningActivations);
end
Mini-Batch Preprocessing Function
The preprocessMiniBatch function preprocesses a mini-batch of predictors and labels using the following steps:
- Preprocess the images using the preprocessMiniBatchPredictors function.
- Extract the label data from the incoming cell array and concatenate it into a categorical array along the second dimension.
- One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
function [X,T] = preprocessMiniBatch(dataX,dataT)
    % Preprocess predictors.
    X = preprocessMiniBatchPredictors(dataX);
    % Extract label data from cell and concatenate.
    T = cat(2,dataT{1:end});
    % One-hot encode labels.
    T = onehotencode(T,1);
end
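To see the shape that the label steps produce, the following sketch, run separately at the command line, encodes three hypothetical labels drawn from three hypothetical categories. Each category becomes one row, including categories that do not appear in the mini-batch, and each observation becomes one column. With the 10 digit classes in this example, T is therefore 10-by-miniBatchSize.

```matlab
% Hypothetical labels for three observations, all sharing the categories "0", "3", "7".
cats = ["0" "3" "7"];
dataT = {categorical("3",cats), categorical("7",cats), categorical("3",cats)};

% Concatenate along the second dimension: a 1-by-3 categorical array.
T = cat(2,dataT{1:end});

% Encode along the first dimension: one row per category, one column per observation.
T = onehotencode(T,1);
disp(T)
```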
Mini-Batch Predictors Preprocessing Function
The preprocessMiniBatchPredictors function preprocesses a mini-batch of predictors by extracting the image data from the input cell array and concatenating into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension.
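function X = preprocessMiniBatchPredictors(dataX)
    % Concatenate.
    X = cat(4,dataX{1:end});
    % Normalize the images to the range [0, 1]. Convert to single first,
    % because dividing integer data (such as uint8) by 255 rounds the result.
    X = single(X)/255;
end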
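The following sketch, which runs in base MATLAB on hypothetical data, shows both steps of this function on three 28-by-28 uint8 images. Concatenating along the fourth dimension adds a singleton channel dimension, and the class of the data determines whether division by 255 keeps fractional values.

```matlab
% Hypothetical mini-batch of three 28-by-28 grayscale uint8 images.
dataX = {uint8(200*ones(28)), uint8(100*ones(28)), uint8(zeros(28))};

% Concatenate along the fourth dimension: height-by-width-by-channel-by-batch.
X = cat(4,dataX{1:end});
disp(size(X))   % 28 28 1 3

% Integer arithmetic rounds: uint8 200 divided by 255 becomes uint8 1.
a = X(1,1,1,1)/255
% Casting to single first keeps the fractional value (about 0.7843).
b = single(X(1,1,1,1))/255
```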
Input Arguments
prunableNet
Network for pruning by using first-order Taylor approximation, specified as a TaylorPrunableNetwork object.
maxToPrune
Maximum number of filters to prune, specified as an integer-valued numeric scalar. If you do not specify MaxToPrune, the function removes up to 8 filters.
Output Arguments
prunableNet_new
Network object from which the filters with the lowest Taylor-based importance scores have been removed, returned as a TaylorPrunableNetwork object.
Version History
Introduced in R2022a