dlupdate

Update parameters using custom function

Description


dlnet = dlupdate(fun,dlnet) updates the learnable parameters of the dlnetwork object dlnet by evaluating the function fun with each learnable parameter as an input. fun is a function handle to a function that takes one parameter array as an input argument and returns an updated parameter array.

params = dlupdate(fun,params) updates the learnable parameters in params by evaluating the function fun with each learnable parameter as an input.

[___] = dlupdate(fun,___,A1,...,An) also passes the additional input arguments A1,...,An to fun, using any of the input arguments in the previous syntaxes, when fun is a function handle to a function that requires n+1 input values.

[___,X1,...,Xm] = dlupdate(fun,___) returns multiple outputs X1,...,Xm when fun is a function handle to a function that returns m+1 output values.
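The params syntaxes are easiest to see on small, plain numeric values. The following is a minimal sketch, assuming Deep Learning Toolbox is installed. The structures params and grads and their values are hypothetical. It shows the one-input form and the form with one additional input argument (n = 1). The dlnet syntax and the multiple-output syntax are not covered here.

```matlab
% Hypothetical parameters and gradients held as plain numeric arrays in
% structures with matching fields (dlarray values are handled the same way).
params.W = [0.5 -3; 2 0.25];
params.b = [1.5; -0.5];
grads.W = [1 1; 1 1];
grads.b = [2; -2];

% params = dlupdate(fun,params): fun takes one input.
% Clip every parameter to the range [-1, 1].
clipped = dlupdate(@(p) min(max(p,-1),1), params);
% clipped.W is [0.5 -1; 1 0.25] and clipped.b is [1; -0.5]

% params = dlupdate(fun,params,A1): fun takes n+1 = 2 inputs, so grads
% must have the same fields and ordering as params.
% Take one gradient step of size 0.1.
stepped = dlupdate(@(p,g) p - 0.1.*g, params, grads);
% stepped.W is [0.4 -3.1; 1.9 0.15] and stepped.b is [1.3; -0.3]
```

Because fun is applied to each field separately, the same anonymous function works whether params holds two arrays or two hundred.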

Examples


Perform L1 regularization on a structure of parameter gradients.

Create the sample input data.

dlX = dlarray(rand(100,100,3),'SSC');

Initialize the learnable parameters for the convolution operation.

params.Weights = dlarray(rand(10,10,3,50));
params.Bias = dlarray(rand(50,1));

Calculate the gradients for the convolution operation using the helper function convGradients, defined at the end of this example.

grads = dlfeval(@convGradients,dlX,params);

Define the regularization factor.

L1Factor = 0.001;

Create an anonymous function that regularizes the gradients. By using an anonymous function to pass a scalar constant to the function, you can avoid having to expand the constant value to the same size and structure as the parameter variable.

L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);

Use dlupdate to apply the regularization function to each of the gradients.

grads = dlupdate(L1Regularizer,grads,params);

The gradients in grads are now regularized according to the function L1Regularizer.
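To see what L1Regularizer does element by element, the following is a small sketch on hypothetical one-field structures g and p (assuming Deep Learning Toolbox for dlupdate). Each gradient element moves by L1Factor in the direction of the sign of the matching parameter. Zero-valued parameters leave their gradient unchanged because sign(0) is 0. The sketch also shows that the anonymous function captures the value of L1Factor when the handle is created. Later changes to the variable therefore do not affect it.

```matlab
L1Factor = 0.001;
L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);

% Hypothetical gradient and parameter structures with matching fields.
g.w = [1 1 1];
p.w = [0.5 -2 0];

gReg = dlupdate(L1Regularizer,g,p);
% sign(p.w) is [1 -1 0], so gReg.w is [1.001 0.999 1]

% Reassigning L1Factor now does not change the captured value.
L1Factor = 0.1;
gReg2 = dlupdate(L1Regularizer,g,p);
% gReg2.w is still [1.001 0.999 1]
```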

convGradients Function

The convGradients helper function takes the learnable parameters of the convolution operation and a mini-batch of input data dlX, and returns the gradients with respect to the learnable parameters.

function grads = convGradients(dlX,params)
dlY = dlconv(dlX,params.Weights,params.Bias);
dlY = sum(dlY,'all');
grads = dlgradient(dlY,params);
end
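Because convGradients sums the convolution output before differentiating, the gradient with respect to each bias term has a simple closed form. Each bias value is added once at every spatial position of its output channel. With 100-by-100 inputs, 10-by-10 filters, and dlconv's default stride of 1 and no padding, the output is 91-by-91. Every element of grads.Bias should therefore equal 91*91 = 8281. The following sketch checks this value, assuming Deep Learning Toolbox. It uses an anonymous stand-in for convGradients (the name gradFcn is hypothetical) so that it runs as a standalone script.

```matlab
% Same sizes as the example above.
dlX = dlarray(rand(100,100,3),'SSC');
params.Weights = dlarray(rand(10,10,3,50));
params.Bias = dlarray(rand(50,1));

% Anonymous equivalent of convGradients; dlgradient runs inside dlfeval.
gradFcn = @(x,p) dlgradient(sum(dlconv(x,p.Weights,p.Bias),'all'),p);
grads = dlfeval(gradFcn,dlX,params);

% Each of the 50 bias terms contributes once per output position.
biasGrad = extractdata(grads.Bias);
% every element of biasGrad is 8281
```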

Use dlupdate to train a network using a custom update function that implements the stochastic gradient descent algorithm (without momentum).

Load Training Data

Load the digits training data.

[XTrain,YTrain] = digitTrain4DArrayData;
classes = categories(YTrain);
numClasses = numel(classes);

Define the Network

Define the network architecture and specify the average image using the 'Mean' option in the image input layer.

layers = [
    imageInputLayer([28 28 1], 'Name','input','Mean',mean(XTrain,4))
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name', 'relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')];
lgraph = layerGraph(layers);

Create a dlnetwork object from the layer graph.

dlnet = dlnetwork(lgraph);

Define the Model Gradients Function

Create the function modelGradients, listed at the end of this example, that takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the loss and the gradients of the loss with respect to the learnable parameters in dlnet.

Define the Stochastic Gradient Descent Function

Create the function sgdFunction, listed at the end of this example, that takes param and paramGradient, a learnable parameter and the gradient of the loss with respect to that parameter, respectively. The function returns the updated parameter using the stochastic gradient descent algorithm, expressed as

$$\theta_{l+1} = \theta_l - \alpha \nabla E(\theta_l)$$

where l is the iteration number, α>0 is the learning rate, θ is the parameter vector, and E(θ) is the loss function.
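As a worked instance of this update, take hypothetical values $\theta_l = (1, 2)$ and $\nabla E(\theta_l) = (0.5, -1)$, with the learning rate $\alpha = 0.01$ that sgdFunction uses:

```latex
\theta_{l+1}
  = \begin{bmatrix} 1 \\ 2 \end{bmatrix}
  - 0.01 \begin{bmatrix} 0.5 \\ -1 \end{bmatrix}
  = \begin{bmatrix} 1 - 0.005 \\ 2 + 0.01 \end{bmatrix}
  = \begin{bmatrix} 0.995 \\ 2.01 \end{bmatrix}
```

Each component moves against the sign of its own gradient. This is why the update can be applied to every learnable array independently, which is the per-parameter pattern that dlupdate implements.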

Specify Training Options

Specify the options to use during training.

miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
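Using floor means that each epoch discards the final partial mini-batch. The following arithmetic sketch assumes that the digits training set contains 5000 images. Under that assumption, each epoch runs 39 full mini-batches and leaves 8 observations unused. Because the data is reshuffled at the start of every epoch, a different set of observations is left out each time.

```matlab
% Assumes the digits training set holds 5000 images.
numObservations = 5000;
miniBatchSize = 128;

numIterationsPerEpoch = floor(numObservations./miniBatchSize);     % 39
numUnused = numObservations - numIterationsPerEpoch*miniBatchSize; % 8
```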

Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher.

executionEnvironment = "auto";

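Specify the learning rate. The sgdFunction helper defined at the end of this example hard-codes this same value, so if you change it, change it in both places.

learnRate = 0.01;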

Initialize the training progress plot.

plots = "training-progress";
if plots == "training-progress"
    iteration = 1;
    figure
    lineLossTrain = animatedline;
    xlabel("Total Iterations")
    ylabel("Loss")
end

Train the Network

Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters by calling dlupdate with the function sgdFunction defined at the end of this example. After each iteration, plot the training loss.

for epoch = 1:numEpochs
    % Shuffle data.
    idx = randperm(numel(YTrain));
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(idx);
    
    for i = 1:numIterationsPerEpoch
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        
        Y = zeros(numClasses, miniBatchSize, 'single');
        for c = 1:numClasses
            Y(c,YTrain(idx)==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [grad,loss] = dlfeval(@modelGradients,dlnet,dlX,Y);
        
        % Update the network parameters using the SGD algorithm defined in the 
        % function sgdFunction.
        dlnet = dlupdate(@sgdFunction,dlnet,grad);
        
        % Display the training progress.
        if plots == "training-progress"
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            title("Loss During Training: Epoch - " + epoch + "; Iteration - " + i)
            drawnow
            iteration = iteration + 1;
        end
    end
end

Test the Network

Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.

[XTest, YTest] = digitTest4DArrayData;

Convert the data to a dlarray object with dimension format 'SSCB'. For GPU prediction, also convert the data to gpuArray.

dlXTest = dlarray(XTest,'SSCB');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlXTest = gpuArray(dlXTest);
end

To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.

dlYPred = predict(dlnet,dlXTest);
[~,idx] = max(extractdata(dlYPred),[],1);
YPred = classes(idx);

Evaluate the classification accuracy.

accuracy = mean(YPred==YTest)
accuracy = 0.7282

Model Gradients Function

The function modelGradients takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the loss and the gradients of the loss with respect to the learnable parameters in dlnet. To compute the gradients automatically, use the dlgradient function.

function [gradients,loss] = modelGradients(dlnet,dlX,Y)
    dlYPred = forward(dlnet,dlX);
    dlYPred = softmax(dlYPred);
    
    loss = crossentropy(dlYPred,Y);
    gradients = dlgradient(loss,dlnet.Learnables);
end

Stochastic Gradient Descent Function

The function sgdFunction takes param and paramGradient, a learnable parameter and the gradient of the loss with respect to that parameter, and returns the updated parameter using the stochastic gradient descent algorithm, expressed as

$$\theta_{l+1} = \theta_l - \alpha \nabla E(\theta_l)$$

where l is the iteration number, α>0 is the learning rate, θ is the parameter vector, and E(θ) is the loss function.

function param = sgdFunction(param,paramGradient)
    learnRate = 0.01;
    param = param - learnRate.*paramGradient;
end
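Because sgdFunction hard-codes its learning rate, the learnRate variable in the training script does not reach it. The following is a minimal sketch of one alternative, assuming Deep Learning Toolbox. It uses the same anonymous-function technique as the L1 regularization example to capture the scalar learning rate. The handle name sgdStep and the check values are hypothetical.

```matlab
learnRate = 0.001;

% Anonymous SGD update that captures the scalar learnRate.
sgdStep = @(param,paramGradient) param - learnRate.*paramGradient;

% In the training loop, the update would become:
%   dlnet = dlupdate(sgdStep,dlnet,grad);

% Quick check on hypothetical numeric values.
p.w = [1 2];
g.w = [10 -10];
p = dlupdate(sgdStep,p,g);
% p.w is now [0.99 2.01]
```

With this handle, changing learnRate in the script is the only edit needed to change the step size.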

Input Arguments


fun

Function to apply to the learnable parameters, specified as a function handle.

dlupdate evaluates fun with each network learnable parameter as an input. fun is evaluated once for each array of learnable parameters in dlnet or params.

dlnet

Network, specified as a dlnetwork object.

The function updates the dlnet.Learnables property of the dlnetwork object. dlnet.Learnables is a table with three variables:

  • Layer — Layer name, specified as a string scalar.

  • Parameter — Parameter name, specified as a string scalar.

  • Value — Value of parameter, specified as a cell array containing a dlarray.
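The following is a minimal sketch of the dlnet form, assuming Deep Learning Toolbox. The network is a hypothetical two-layer network with a 4-by-4 grayscale input and a fully connected layer with 2 outputs. Its Learnables table has one row for each learnable array (the Weights and Bias of the fully connected layer), so fun is evaluated twice.

```matlab
% Hypothetical tiny network.
layers = [
    imageInputLayer([4 4 1],'Normalization','none','Name','input')
    fullyConnectedLayer(2,'Name','fc')];
dlnet = dlnetwork(layerGraph(layers));

% Layer, Parameter, and Value variables; one row per learnable array.
disp(dlnet.Learnables)

% Double every learnable parameter and compare one entry before and after.
before = extractdata(dlnet.Learnables.Value{1});
dlnet = dlupdate(@(p) 2.*p, dlnet);
after = extractdata(dlnet.Learnables.Value{1});
```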

params

Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.

If you specify params as a table, it must contain the following three variables.

  • Layer — Layer name, specified as a string scalar.

  • Parameter — Parameter name, specified as a string scalar.

  • Value — Value of parameter, specified as a cell array containing a dlarray.
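The following is a minimal sketch of the table form of params, assuming Deep Learning Toolbox. It uses a hypothetical table that mimics the layout of dlnet.Learnables. fun receives each entry of the Value variable as a dlarray.

```matlab
% Hypothetical table laid out like dlnet.Learnables.
Layer = ["fc"; "fc"];
Parameter = ["Weights"; "Bias"];
Value = {dlarray([1 2; 3 4]); dlarray([0; 1])};
params = table(Layer,Parameter,Value);

% Halve every learnable parameter.
params = dlupdate(@(p) 0.5.*p, params);

W = extractdata(params.Value{1});   % [0.5 1; 1.5 2]
b = extractdata(params.Value{2});   % [0; 0.5]
```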

You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.
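Nested containers are traversed down to their numeric or dlarray leaves. The following sketch assumes Deep Learning Toolbox and uses hypothetical values. It applies fun to the three leaf arrays of a cell array that holds a structure and a numeric array.

```matlab
% Hypothetical nested container: a cell array holding a structure and a
% numeric array, for three leaf arrays in total.
params = {struct('W',[1 2],'b',3), [4 5 6]};

params = dlupdate(@(p) p + 1, params);
% params{1}.W is [2 3], params{1}.b is 4, params{2} is [5 6 7]
```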

Any additional input arguments A1,...,An must have exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.

Data Types: single | double | struct | table | cell

A1,...,An

Additional input arguments to fun, specified as dlarray objects, numeric arrays, cell arrays, structures, or tables with a Value variable.

The exact form of A1,...,An depends on the input network or learnable parameters. The following list shows the required format of A1,...,An for each possible input to dlupdate.

  • dlnet: The learnable parameters are the table dlnet.Learnables, which contains Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. A1,...,An must be tables with the same data type, variables, and ordering as dlnet.Learnables. Each must have a Value variable consisting of cell arrays that contain the additional input arguments for the function fun to apply to each learnable parameter.

  • params specified as a dlarray: A1,...,An must be dlarray objects with the same data type and ordering as params.

  • params specified as a numeric array: A1,...,An must be numeric arrays with the same data type and ordering as params.

  • params specified as a cell array: A1,...,An must be cell arrays with the same data types, structure, and ordering as params.

  • params specified as a structure: A1,...,An must be structures with the same data types, fields, and ordering as params.

  • params specified as a table with Layer, Parameter, and Value variables, where the Value variable consists of cell arrays that contain each learnable parameter as a dlarray: A1,...,An must be tables with the same data types, variables, and ordering as params. Each must have a Value variable consisting of cell arrays that contain the additional input argument for the function fun to apply to each learnable parameter.

Output Arguments


dlnet

Network, returned as a dlnetwork object.

The function updates the dlnet.Learnables property of the dlnetwork object.

params

Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.

X1,...,Xm

Additional output arguments from the function fun, where fun is a function handle to a function that returns multiple outputs, returned as dlarray objects, numeric arrays, cell arrays, structures, or tables with a Value variable.

The exact form of X1,...,Xm depends on the input network or learnable parameters. The following list shows the returned format of X1,...,Xm for each possible input to dlupdate.

  • dlnet: The learnable parameters are the table dlnet.Learnables, which contains Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. X1,...,Xm are tables with the same data type, variables, and ordering as dlnet.Learnables. Each has a Value variable consisting of cell arrays that contain the additional output arguments of the function fun applied to each learnable parameter.

  • params specified as a dlarray: X1,...,Xm are dlarray objects with the same data type and ordering as params.

  • params specified as a numeric array: X1,...,Xm are numeric arrays with the same data type and ordering as params.

  • params specified as a cell array: X1,...,Xm are cell arrays with the same data types, structure, and ordering as params.

  • params specified as a structure: X1,...,Xm are structures with the same data types, fields, and ordering as params.

  • params specified as a table with Layer, Parameter, and Value variables, where the Value variable consists of cell arrays that contain each learnable parameter as a dlarray: X1,...,Xm are tables with the same data types, variables, and ordering as params. Each has a Value variable consisting of cell arrays that contain the additional output argument of the function fun applied to each learnable parameter.
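A typical use of an additional output is per-parameter optimizer state. The following sketch uses SGD with momentum, a standard variant that differs from the plain SGD in the example above. The function fun returns m+1 = 2 outputs: the updated parameter and the updated velocity. The sketch assumes Deep Learning Toolbox and assumes that dlupdate requests both outputs from fun, as the X1,...,Xm syntax describes. The handle momentumStep and all values are hypothetical. deal lets an anonymous function return two values.

```matlab
learnRate = 0.01;
momentum = 0.9;

% Returns the updated parameter and the updated velocity.
momentumStep = @(param,grad,vel) deal( ...
    param + momentum.*vel - learnRate.*grad, ...
    momentum.*vel - learnRate.*grad);

% Hypothetical parameters, gradients, and velocities with matching fields.
params.w = [1 2];
grads.w = [10 -10];
vel.w = [0 0];

% X1 (here vel) is returned with the same fields as params.
[params,vel] = dlupdate(momentumStep,params,grads,vel);
% params.w is now [0.9 2.1] and vel.w is [-0.1 0.1]
```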

Introduced in R2019b