How to use Nadam optimizer in training deep neural networks
kollikonda Ashok kumar
on 29 Mar 2023
Commented: Amanjit Dulai
on 25 Oct 2024 at 11:06
Training_Options = trainingOptions('sgdm', ...
'MiniBatchSize', 32, ...
'MaxEpochs', 50, ...
"InitialLearnRate", 1e-5, ...
'Shuffle', 'every-epoch', ...
'ValidationData', Resized_Validation_Data, ...
'ValidationFrequency', 40, ...
"ExecutionEnvironment","gpu",...
'Plots','training-progress', ...
'Verbose',false);
Answers (2)
Nayan
on 5 Apr 2023
Hi,
I assume you want to use the "adam" optimizer in place of "sgdm". Simply replace the "sgdm" solver name with "adam" in the call to trainingOptions:
options = trainingOptions("adam", ...
InitialLearnRate=3e-4, ...
SquaredGradientDecayFactor=0.99, ...
MaxEpochs=20, ...
MiniBatchSize=64, ...
Plots="training-progress")
Amanjit Dulai
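on 25 Oct 2024 at 11:02
You can train with Nadam by defining a custom training loop. The function dlupdate can be used to define custom update rules for training. The rules for Nadam, in the form used by the code below, are:
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$
$$n_t = \beta_2 n_{t-1} + (1 - \beta_2)\, g_t^2$$
$$\hat{m}_t = \frac{\mu_{t+1}\, m_t}{1 - \prod_{i=1}^{t+1} \mu_i} + \frac{(1 - \mu_t)\, g_t}{1 - \prod_{i=1}^{t} \mu_i}$$
$$\hat{n}_t = \frac{n_t}{1 - \beta_2^t}$$
$$\theta_t = \theta_{t-1} - \frac{\alpha\, \hat{m}_t}{\sqrt{\hat{n}_t} + \epsilon}$$
where the momentum is given by:
$$\mu_t = \beta_1 \left(1 - 0.5 \cdot 0.96^{\,t \psi}\right)$$
Here $g_t$ is the gradient at iteration $t$, $\theta_t$ are the learnable parameters, $\alpha$ is learnRate, $\beta_1$ is gradientDecay, $\beta_2$ is squaredGradientDecay, $\psi$ is momentumDecay, and $\epsilon$ is epsilon.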
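To see the update rules in action without a network, here is a minimal sketch that applies them to a single scalar parameter. The objective f(w) = (w - 3)^2, the learning rate of 0.1, the step count, and the running product `muProduct` are assumptions chosen for illustration; they are not part of the original answer.

```matlab
% Scalar Nadam sketch on f(w) = (w - 3)^2 (toy problem, illustrative values).
learnRate = 0.1;              % alpha, assumed for this toy problem
gradientDecay = 0.9;          % beta1
squaredGradientDecay = 0.99;  % beta2
momentumDecay = 0.004;        % psi
epsilon = 1e-8;

% Momentum schedule mu_t = beta1*(1 - 0.5*0.96^(t*psi)).
mu = @(t) gradientDecay*(1 - 0.5*0.96^(t*momentumDecay));

w = 0;          % parameter being optimized
m = 0;          % first-moment estimate
n = 0;          % second-moment estimate
muProduct = 1;  % running product mu_1*...*mu_t

numSteps = 200;
for t = 1:numSteps
    g = 2*(w - 3);                          % gradient of the toy objective
    m = gradientDecay*m + (1 - gradientDecay)*g;
    n = squaredGradientDecay*n + (1 - squaredGradientDecay)*g^2;
    muProduct = muProduct*mu(t);            % prod over i = 1..t of mu_i
    mHat = mu(t+1)*m/(1 - muProduct*mu(t+1)) + (1 - mu(t))*g/(1 - muProduct);
    nHat = n/(1 - squaredGradientDecay^t);
    w = w - learnRate*mHat/(sqrt(nHat) + epsilon);
end
fprintf('w after %d steps: %.4f\n', numSteps, w);
```

The parameter should settle near the minimizer w = 3. The loop body maps one-to-one onto the moment, bias-correction, and parameter updates, so it can serve as a reference when checking a dlupdate-based implementation.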
Below is an example of how to train a digit classification network using Nadam in a custom training loop:
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
imageInputLayer([28 28 1], Normalization="none")
convolution2dLayer(5, 20)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer
]);
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
MiniBatchSize = miniBatchSize,...
MiniBatchFcn = @preprocessMiniBatch,...
MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
Metrics = "Loss", ...
Info = "Epoch", ...
XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);
        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;
        % Update the dlnetwork according to Nadam
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        momentums = [momentums nextMomentum]; %#ok<AGROW>
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
            ((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
            velocity, gradients);
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat );
        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
    [Y, state] = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
    X = cat(4,XCell{1:end});
    T = cat(2,TCell{1:end});
    T = onehotencode(T,1);
end
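The training loop above appends to the `momentums` array on every iteration and recomputes `prod` over its whole history. One possible alternative, sketched below under the assumption that only the cumulative product is needed, is to carry a running product forward instead. The names `mu`, `muProduct`, `denomArray`, and `denomRunning` are hypothetical and are not part of the original answer. The sketch only checks that both approaches produce the same bias-correction denominators.

```matlab
% Compare two ways of computing 1 - prod(mu_1..mu_t) (plain numeric sketch).
gradientDecay = 0.9;
momentumDecay = 0.004;
numIterations = 50;

% Momentum schedule, same formula as in the training loop.
mu = @(t) gradientDecay*(1 - 0.5*0.96^(t*momentumDecay));

% Approach used in the training loop: grow an array, call prod each time.
momentums = mu(1);
denomArray = zeros(1, numIterations);
for iteration = 1:numIterations
    momentums = [momentums mu(iteration + 1)]; %#ok<AGROW>
    denomArray(iteration) = 1 - prod(momentums(1:iteration));
end

% Alternative: keep a running product, one multiplication per iteration.
muProduct = 1;
denomRunning = zeros(1, numIterations);
for iteration = 1:numIterations
    muProduct = muProduct*mu(iteration);
    denomRunning(iteration) = 1 - muProduct;
end

maxDifference = max(abs(denomArray - denomRunning));
fprintf('Largest difference between the two approaches: %g\n', maxDifference);
```

With a running product, the denominator for iteration + 1 is `1 - muProduct*mu(iteration + 1)`, so the per-iteration cost stays constant instead of growing with the iteration count. For the short run in this example the difference is negligible; it may matter for long training runs.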
1 Comment
Amanjit Dulai
on 25 Oct 2024 at 11:06
Also, if you want to use weight decay only on the weights, you can modify the example as shown below:
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
imageInputLayer([28 28 1], Normalization="none")
convolution2dLayer(5, 20)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer
]);
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
l2RegularizationFactor = 0.0001;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
l2Indices = ~(net.Learnables.Parameter == "Bias");
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
MiniBatchSize = miniBatchSize,...
MiniBatchFcn = @preprocessMiniBatch,...
MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
Metrics = "Loss", ...
Info = "Epoch", ...
XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);
        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;
        % Apply weight regularization (L2 term added to the gradients)
        gradients(l2Indices,:) = dlupdate( @(g,w)g + l2RegularizationFactor*w, ...
            gradients(l2Indices,:), net.Learnables(l2Indices,:) );
        % Update the dlnetwork according to Nadam
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        momentums = [momentums nextMomentum]; %#ok<AGROW>
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
            ((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
            velocity, gradients);
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat );
        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
    [Y, state] = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
    X = cat(4,XCell{1:end});
    T = cat(2,TCell{1:end});
    T = onehotencode(T,1);
end
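It can help to check which rows the `l2Indices` mask actually selects. The sketch below uses a small hand-built table as a hypothetical stand-in for `net.Learnables` (the layer names and parameter order are assumptions, and the `Value` column is left out). It shows that excluding only "Bias" still selects the batch normalization "Offset" and "Scale" parameters. If the intent is to regularize the weights alone, a stricter mask such as `weightsOnly` is one option; that is an assumption about intent, not something the original answer states.

```matlab
% Hypothetical stand-in for net.Learnables (Layer and Parameter columns only).
Layer = ["conv"; "conv"; "batchnorm"; "batchnorm"; "fc"; "fc"];
Parameter = ["Weights"; "Bias"; "Offset"; "Scale"; "Weights"; "Bias"];
learnables = table(Layer, Parameter);

% Mask used in the example: every parameter that is not a bias.
l2Indices = ~(learnables.Parameter == "Bias");

% Stricter alternative (assumed intent): convolution and fully connected weights only.
weightsOnly = learnables.Parameter == "Weights";

disp('Rows selected by l2Indices:')
disp(learnables(l2Indices, :))
disp('Rows selected by weightsOnly:')
disp(learnables(weightsOnly, :))
```

Either mask can be used for the row indexing in the regularization step, since both are logical column vectors with one entry per row of the table.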
One thing to note is that with adaptive learning rules like Adam and Nadam, it is often more effective to apply weight decay directly to the weights instead of adding it to the gradients. Applied to Nadam, this gives the algorithm NadamW. Below is an example of how to use NadamW:
% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);
% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
imageInputLayer([28 28 1], Normalization="none")
convolution2dLayer(5, 20)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer
]);
% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-08;
l2RegularizationFactor = 0.0001;
momentums = gradientDecay*(1 - 0.5*0.96^momentumDecay);
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
l2Indices = ~(net.Learnables.Parameter == "Bias");
% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
MiniBatchSize = miniBatchSize,...
MiniBatchFcn = @preprocessMiniBatch,...
MiniBatchFormat = {'SSCB',''});
% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);
% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
Metrics = "Loss", ...
Info = "Epoch", ...
XLabel = "Iteration");
% Train the network
numObservationsTrain = numel(TTrain);
numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);
        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;
        % Apply decoupled weight regularization (NadamW)
        net.Learnables(l2Indices,:) = dlupdate( @(w)w - learnRate*l2RegularizationFactor*w, ...
            net.Learnables(l2Indices,:) );
        % Update the dlnetwork according to Nadam
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        momentums = [momentums nextMomentum]; %#ok<AGROW>
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        velocityHat = dlupdate(@(v,g)(momentums(iteration+1) .* v) ./ (1-prod(momentums(1:(iteration+1)))) + ...
            ((1-momentums(iteration)) .* g) ./ (1-prod(momentums(1:iteration))), ...
            velocity, gradients);
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat );
        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end
% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
    [Y, state] = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
    X = cat(4,XCell{1:end});
    T = cat(2,TCell{1:end});
    T = onehotencode(T,1);
end
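To make the difference between the two weight-decay variants concrete, the following sketch applies a single first Nadam step (t = 1, zero moment estimates) to one scalar weight. To isolate the effect of regularization, it assumes the loss gradient is zero. The helper handles `mu`, `mHat`, `nHat`, and `nadamStep` and the starting weight `w0` are hypothetical names introduced for illustration only.

```matlab
% Coupled L2 regularization vs decoupled weight decay on one scalar weight.
learnRate = 0.001;
gradientDecay = 0.9;
squaredGradientDecay = 0.99;
momentumDecay = 0.004;
epsilon = 1e-8;
l2RegularizationFactor = 1e-4;

mu = @(t) gradientDecay*(1 - 0.5*0.96^(t*momentumDecay));

% First Nadam step (t = 1) starting from zero moment estimates.
mHat = @(g) mu(2)*(1 - gradientDecay)*g/(1 - mu(1)*mu(2)) + (1 - mu(1))*g/(1 - mu(1));
nHat = @(g) (1 - squaredGradientDecay)*g^2/(1 - squaredGradientDecay);
nadamStep = @(g) learnRate*mHat(g)/(sqrt(nHat(g)) + epsilon);

w0 = 1;            % starting weight (assumed)
lossGradient = 0;  % assume the loss contributes no gradient

% Coupled: the L2 term is added to the gradient and then normalized by Nadam.
wCoupled = w0 - nadamStep(lossGradient + l2RegularizationFactor*w0);

% Decoupled (NadamW): shrink the weight directly, then take the Nadam step.
wDecoupled = w0 - learnRate*l2RegularizationFactor*w0;
wDecoupled = wDecoupled - nadamStep(lossGradient);

fprintf('Change with coupled L2:      %.3e\n', w0 - wCoupled);
fprintf('Change with decoupled decay: %.3e\n', w0 - wDecoupled);
```

In this toy setting the coupled variant moves the weight by roughly the learning rate, because the adaptive normalization divides the small L2 term by its own magnitude. The decoupled variant shrinks the weight by exactly learnRate*l2RegularizationFactor*w0. In real training the loss gradient is not zero, so the gap is usually less extreme than this sketch suggests.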