How to use a vanilla SGD solver in trainingOptions?

Mira mosad on 20 Dec 2022
Answered: Meet on 12 Sep 2024
When I use vanilla SGD instead of the Adam solver, the code throws the error "invalid solver name".
How can I use vanilla SGD instead of the Adam solver?
This is the training-options part of my code:
options = trainingOptions('sgdm', ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');

Answers (1)

Meet on 12 Sep 2024
Hi Mira,
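Vanilla SGD is not available as a built-in solver in the "trainingOptions" function: the solver names it accepts include 'sgdm' (SGD with momentum), 'rmsprop', and 'adam', so a name such as 'sgd' is rejected as invalid. However, you can implement the SGD update yourself and use it in a custom training loop.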
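If you would rather stay with trainingOptions, one possible workaround (based on the documented 'sgdm' update rule, not tested against your data) is to switch momentum off. The 'sgdm' step adds the previous step scaled by 'Momentum', so with 'Momentum' set to 0 the update reduces to a plain gradient step. Note that 'sgdm' also applies L2 regularization by default, so setting 'L2Regularization' to 0 as well gives the bare update. A sketch based on the options in the question:

```matlab
% Sketch: 'sgdm' with momentum disabled behaves like plain (vanilla) SGD.
options = trainingOptions('sgdm', ...
    'Momentum',0, ...            % no contribution from the previous step
    'L2Regularization',0, ...    % optional: drop the default L2 weight penalty
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');
```

The learning rate stays constant here because the default learning-rate schedule is 'none'.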
Below is code that defines a custom SGD update function and a custom training loop that uses it.
Custom SGD function:
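function parameters = sgdStep(parameters,gradients,learnRate)
% Vanilla SGD update: move each parameter against its gradient.
% Save as sgdStep.m, or place it at the end of your training script.
parameters = parameters - learnRate .* gradients;
end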
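As a quick sanity check of the update rule, here is the same step written as an anonymous function (sgdStepFcn is a hypothetical name, convenient at the command line) and applied to small illustrative arrays. Each parameter moves by learnRate times its gradient, in the direction opposite to the gradient:

```matlab
% Same update as sgdStep, written as an anonymous function for quick testing.
sgdStepFcn = @(parameters,gradients,learnRate) parameters - learnRate .* gradients;

% Illustrative values only.
parameters = [1.0; -2.0];
gradients  = [0.5; -1.0];
learnRate  = 0.1;

updated = sgdStepFcn(parameters,gradients,learnRate);
disp(updated)   % 0.9500 and -1.9000
```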
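Custom training loop:
% Settings (illustrative values; adjust for your problem).
numEpochs = 20;
learnRate = 1e-4;
% This loop assumes these already exist (hypothetical names for your own
% setup): net (a dlnetwork), mbq (a minibatchqueue over the training data),
% numObservations, miniBatchSize, and a modelLoss function.
numIterations = numEpochs * ceil(numObservations/miniBatchSize);
% Monitor that plots the loss against the iteration number.
monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");
% SGD update applied to every learnable parameter by dlupdate.
updateFcn = @(parameters,gradients) sgdStep(parameters,gradients,learnRate);
epoch = 0;
iteration = 0;
% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
    % Shuffle data.
    shuffle(mbq);
    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;
        % Read mini-batch of data.
        [X,T] = next(mbq);
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelLoss function, then update the network state.
        [loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
        net.State = state;
        % Update the network parameters using SGD.
        net = dlupdate(updateFcn,net,gradients);
        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch);
        monitor.Progress = 100 * iteration/numIterations;
    end
end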
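The loop calls a modelLoss function that is not shown above. A minimal version for a classification network might look like the following. It assumes that net is a dlnetwork ending in a softmax layer and that T holds one-hot encoded targets; adapt the loss to your own task.

```matlab
function [loss,gradients,state] = modelLoss(net,X,T)
% Forward pass in training mode; state holds updated values such as
% batch normalization statistics.
[Y,state] = forward(net,X);
% Cross-entropy loss (assumes softmax outputs and one-hot targets).
loss = crossentropy(Y,T);
% Gradients of the loss with respect to the learnable parameters.
gradients = dlgradient(loss,net.Learnables);
end
```

Because it uses dlgradient, modelLoss must be called through dlfeval, as the loop does.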
You can refer to the resource below for more information:
