Can you plot the gradient for CNNs using trainNetwork?

Arjun Desai on 27 May 2018
Answered: Snehal on 27 Mar 2025
I am using the trainNetwork command to train my network, but I noticed that there is no way to plot the gradients over iterations. The trainInfo output contains some training metrics, but it does not seem to contain any gradient information.

Answers (1)

Snehal on 27 Mar 2025
I understand that you want to extract gradient information while training a CNN and plot it over iterations. The ‘trainNetwork’ function in MATLAB does not expose gradients during training, but there are two possible workarounds:
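  • Write a custom training loop around a ‘dlnetwork’ object. In such a loop you compute the gradients yourself with ‘dlgradient’ at every iteration, so you can record and plot them (custom training loops require R2019b or later). Below is a sample code snippet that computes the gradients for one mini-batch: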
% 'layers' is a layer array defined earlier; it must not end in a classificationLayer
net = dlnetwork(layers);
% Assume 'XBatch' is a formatted dlarray (e.g. 'SSCB') and 'YBatch' holds one-hot targets
[loss, gradients] = dlfeval(@modelLoss, net, XBatch, YBatch);
% 'gradients' is a table of the loss gradients with respect to net.Learnables

function [loss, gradients] = modelLoss(net, X, T)
    % Forward pass
    Y = forward(net, X);
    % Compute the cross-entropy loss
    loss = crossentropy(Y, T);
    % dlgradient only works inside a function evaluated by dlfeval
    gradients = dlgradient(loss, net.Learnables);
end
  • If you want to keep using ‘trainNetwork’, you can pass a custom output function through the ‘OutputFcn’ option of ‘trainingOptions’. It is called after each iteration with a structure of training metrics such as ‘TrainingLoss’ and, when validation data is used, ‘ValidationLoss’. These are losses, not gradients, but how quickly they change over iterations can help you spot gradient-related problems during training.
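For the first workaround, here is a minimal sketch, assuming a recent Deep Learning Toolbox release with custom-training-loop support (dlnetwork, dlarray, dlfeval, dlgradient, adamupdate). It trains with a custom loop, records the global L2 norm of all gradients at every iteration, and plots that norm. The toy data, the small network, the helper name modelLoss and the hyperparameters are hypothetical placeholders, not part of the original answer; replace them with your own.

```matlab
% Hypothetical toy data and network so the sketch runs on its own; replace with your own
rng(0);
XTrain = rand(28, 28, 1, 512, 'single');      % H-by-W-by-C-by-N images
TTrain = categorical(randi(3, 512, 1));       % N-by-1 class labels
layers = [
    imageInputLayer([28 28 1], 'Normalization', 'none')
    convolution2dLayer(3, 8, 'Padding', 'same')
    reluLayer
    fullyConnectedLayer(3)
    softmaxLayer];                            % no classificationLayer for dlnetwork

net = dlnetwork(layers);
numClasses = numel(categories(TTrain));
numObs = numel(TTrain);
miniBatchSize = 64;      % hypothetical hyperparameters
numEpochs = 3;
learnRate = 1e-3;

avgGrad = [];            % Adam state, initialized empty
avgSqGrad = [];
gradNorms = [];          % one gradient norm per iteration
iteration = 0;

for epoch = 1:numEpochs
    order = randperm(numObs);
    for i = 1:floor(numObs / miniBatchSize)
        iteration = iteration + 1;
        idx = order((i-1)*miniBatchSize + 1 : i*miniBatchSize);

        % Mini-batch of images as a formatted dlarray
        X = dlarray(XTrain(:, :, :, idx), 'SSCB');
        % One-hot targets: numClasses-by-miniBatchSize
        T = zeros(numClasses, miniBatchSize, 'single');
        T(sub2ind(size(T), double(TTrain(idx))', 1:miniBatchSize)) = 1;
        T = dlarray(T, 'CB');

        % Loss and gradients via automatic differentiation
        [loss, gradients] = dlfeval(@modelLoss, net, X, T);

        % Global L2 norm over the gradients of all learnable parameters
        sq = cellfun(@(g) double(gather(sum(extractdata(g).^2, 'all'))), gradients.Value);
        gradNorms(iteration) = sqrt(sum(sq)); %#ok<SAGROW>

        % Adam parameter update
        [net, avgGrad, avgSqGrad] = adamupdate(net, gradients, ...
            avgGrad, avgSqGrad, iteration, learnRate);
    end
end

figure
plot(1:iteration, gradNorms)
xlabel('Iteration')
ylabel('Global gradient L2 norm')

function [loss, gradients] = modelLoss(net, X, T)
    % Forward pass, cross-entropy loss, and gradients w.r.t. the learnables
    Y = forward(net, X);
    loss = crossentropy(Y, T);
    gradients = dlgradient(loss, net.Learnables);
end
```

A single global norm keeps the plot readable; to inspect individual layers, you could store each entry of gradients.Value separately instead. The loss and gradient computation lives in a local function because dlgradient must be evaluated through dlfeval, and MATLAB requires local functions to sit at the end of a script.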
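For the second workaround, a sketch of an output function that plots ‘TrainingLoss’ after every iteration while ‘trainNetwork’ runs. The toy data, network, option values and the helper name logLoss are hypothetical placeholders. Note that this plots the loss, not the gradients: ‘trainNetwork’ does not pass gradient values to output functions, so a flattening curve or sudden spikes in the loss are only indirect hints about very small or exploding gradients.

```matlab
% Hypothetical toy data and network; trainNetwork needs a classificationLayer
rng(0);
XTrain = rand(28, 28, 1, 512, 'single');
YTrain = categorical(randi(3, 512, 1));
layers = [
    imageInputLayer([28 28 1], 'Normalization', 'none')
    convolution2dLayer(3, 8, 'Padding', 'same')
    reluLayer
    fullyConnectedLayer(3)
    softmaxLayer
    classificationLayer];

% Line that the output function extends after every iteration
figure
h = animatedline('Marker', '.');
xlabel('Iteration')
ylabel('Training loss')

opts = trainingOptions('sgdm', ...
    'MaxEpochs', 3, ...
    'MiniBatchSize', 64, ...
    'Verbose', false, ...
    'OutputFcn', @(info) logLoss(info, h));

net = trainNetwork(XTrain, YTrain, layers, opts);

function stop = logLoss(info, h)
    % Called by trainNetwork before training, after each iteration, and at the end
    stop = false;   % never request early stopping
    if strcmp(info.State, 'iteration') && ~isempty(info.TrainingLoss)
        addpoints(h, info.Iteration, double(info.TrainingLoss));
        drawnow limitrate
    end
end
```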
The following documentation pages provide more information:
Hope this helps.

Categories

Find more on Image Data Workflows in Help Center and File Exchange

Version

R2018a
