How do I debug a convolutional neural network with a custom training loop that is not learning?
9 views (last 30 days)
Marissa Brown on 23 Jun 2023
Commented: Richard on 26 Jun 2023
Hello! I have been trying to design a CNN for image analysis. The CNN is training on simulated images of size 132 x 132 x 6 (spatial, spatial, channel). The simulated images are computed using a bi-exponential equation of the form S(b) = S0 * (f*exp(-b*Df) + (1 - f)*exp(-b*Ds)). In the CNN, the input images are forward passed through the network to generate four feature maps (f, Df, S0, and Ds), which are then scaled and used to calculate the predicted image signals, Spred. The predicted image signals Spred are then compared to the input image signals S using the mean squared error loss function, and the gradients are updated. The problem is that the network is not learning. After some inspection I noticed that the gradients are all going to zero, but I am not sure how to fix this. I have tried changing the learning rate, the optimizer (Adam vs. SGDM), and the mini-batch size, but I encounter the same problem. Any advice/feedback is greatly appreciated!
Also, I have removed parts of the code to make it as simple as possible for the time being, but will add in validation and testing loops.
% Image Parameters
rng(1);
imageSize = [132, 132];
bValue = [50 100 150 250 500 800]; % non-zero diffusion weightings
numbVal = length(bValue);
minDf = 0.0017; 
maxDf = 0.107; 
minf = 0.1; 
maxf = 0.5;
minDs = 0.0003; 
maxDs = 0.0017;
DfSim = minDf + (maxDf-minDf).*rand(10,1);
fSim = minf + (maxf-minf).*rand(10,1);
DsSim = minDs + (maxDs-minDs).*rand(10,1);
numIm = length(DfSim) * length(fSim) * length(DsSim);      % number of 132 x 132 x 6 images
tissue = ones(imageSize);
bValue = reshape(bValue, [1,1,numbVal]); % Reshape bValue for matrix operation
% Prepare a directory to store the simulated images
outputDir = fullfile(tempdir, 'SimulatedDW-MRI');
if ~exist(outputDir, 'dir')
    mkdir(outputDir);
end
% Initialize a table to store the image file paths and parameters
fprintf('Total simulated images: %d\n', numIm);
imageData = table('Size', [0 4],...
          'VariableTypes', {'cell', 'double', 'double', 'double'},...
          'VariableNames', {'imageFilePath', 'DfSim', 'fSim', 'DsSim'});
% Start the timer
tic;
% Loop through each combination of DfSim, fSim, and DsSim
imageIdx = 0;
S = zeros([imageSize length(bValue) numIm]);
for DfIdx = 1:length(DfSim)
    for fIdx = 1:length(fSim)
        for DsIdx = 1:length(DsSim)
            imageIdx = imageIdx + 1;
            % Calculate the diffusion signal for each b value for each channel
            S(:,:,:,imageIdx) = tissue .* ((fSim(fIdx) .* exp(-bValue .* DfSim(DfIdx))) + ((1-fSim(fIdx)) .* exp(-bValue .* DsSim(DsIdx))));
            % Track progress
            fprintf('Processing image %d out of %d\n', imageIdx, numIm);
        end
    end
end
for imageIdx = 1:numIm
    fileName = sprintf('%s/image%d.mat', outputDir, imageIdx);      % Write the image to a .mat file
    S_single = S(:,:,:,imageIdx);
    save(fileName, 'S_single');
    DfIdx = ceil(imageIdx / (length(fSim)*length(DsSim)));
    fIdx = ceil((imageIdx - (DfIdx-1)*length(fSim)*length(DsSim)) / length(DsSim));
    DsIdx = imageIdx - (DfIdx-1)*length(fSim)*length(DsSim) - (fIdx-1)*length(DsSim);
    imageData(imageIdx, :) = {fileName, DfSim(DfIdx), fSim(fIdx), DsSim(DsIdx)};
    fprintf('Saving image %d out of %d\n', imageIdx, numIm);
end
elapsedTime = toc;
fprintf('Computation time: %.2f seconds\n', elapsedTime);
%% Split data in training, validation, and testing sets
trainSplit = 0.8;
valSplit = 0.1;
testSplit = 0.1;
n = height(imageData);
idx = randperm(n);
trainIdx = idx(1:round(trainSplit*n));
valIdx = idx(round(trainSplit*n)+1:round((trainSplit+valSplit)*n));
testIdx = idx(round((trainSplit+valSplit)*n)+1:end);
imageDataTrain = imageData(trainIdx, :);
imageDataVal = imageData(valIdx, :);
imageDataTest = imageData(testIdx, :);
trainImds = fileDatastore(imageDataTrain.imageFilePath, ...
                          'ReadFcn' , @(filename) double(load(filename).S_single), ...
                          'FileExtensions', '.mat');
trainLabelsDatastore = arrayDatastore(imageDataTrain{:, {'DfSim', 'fSim', 'DsSim'}});
trainCombinedDatastore = combine(trainImds, trainLabelsDatastore);
valImds = fileDatastore(imageDataVal.imageFilePath, ...
                          'ReadFcn' , @(filename) double(load(filename).S_single), ...
                          'FileExtensions', '.mat');
valLabelsDatastore = arrayDatastore(imageDataVal{:, {'DfSim', 'fSim', 'DsSim'}});
valCombinedDatastore = combine(valImds, valLabelsDatastore);
testImds = fileDatastore(imageDataTest.imageFilePath, ...
                          'ReadFcn' , @(filename) double(load(filename).S_single), ...
                          'FileExtensions', '.mat');
testLabelsDatastore = arrayDatastore(imageDataTest{:, {'DfSim', 'fSim', 'DsSim'}});
testCombinedDatastore = combine(testImds, testLabelsDatastore);
%% Define the network layers
lgraph = layerGraph();
Layers = [
    imageInputLayer([132 132 6],"Name","imageinput","Normalization","none")
    convolution2dLayer([1 1],32,"Name","conv_1","Padding","same")
    batchNormalizationLayer("Name","batchnorm_1")
    leakyReluLayer("Name","relu_1")
    dropoutLayer(0.02,"Name","dropout_1")
    convolution2dLayer([3 3],32,"Name","conv_2","Padding","same")
    leakyReluLayer("Name","relu_2")
    dropoutLayer(0.02,"Name","dropout_2")
    convolution2dLayer([1 1],64,"Name","conv_3","Padding","same")
    batchNormalizationLayer("Name","batchnorm_2")
    leakyReluLayer("Name","relu_3")
    dropoutLayer(0.02,"Name","dropout_3")
    convolution2dLayer([3 3],64,"Name","conv_4","Padding","same")
    leakyReluLayer("Name","relu_4")
    dropoutLayer(0.02,"Name","dropout_4")
    convolution2dLayer([1 1],128,"Name","conv_5","Padding","same")
    batchNormalizationLayer("Name","batchnorm_3")
    leakyReluLayer("Name","relu_5")
    dropoutLayer(0.02,"Name","dropout_5")
    convolution2dLayer([3 3],128,"Name","conv_6","Padding","same")
    leakyReluLayer("Name","relu_6")
    dropoutLayer(0.02,"Name","dropout_6")
    convolution2dLayer([1 1],64,"Name","conv_7","Padding","same")
    batchNormalizationLayer("Name","batchnorm_4")
    leakyReluLayer("Name","relu_7")
    dropoutLayer(0.02,"Name","dropout_7")
    convolution2dLayer([3 3],64,"Name","conv_8","Padding","same")
    leakyReluLayer("Name","relu_8")
    dropoutLayer(0.02,"Name","dropout_8")
    convolution2dLayer([1 1],32,"Name","conv_9","Padding","same")
    batchNormalizationLayer("Name","batchnorm_5")
    leakyReluLayer("Name","relu_9")
    dropoutLayer(0.02,"Name","dropout_9")
    convolution2dLayer([3 3],32,"Name","conv_10","Padding","same")
    leakyReluLayer("Name","relu_10")
    dropoutLayer(0.02,"Name","dropout_10")
    convolution2dLayer([1 1],4,"Name","conv_11","Padding","same")
    sigmoidLayer("Name","sigmoid")];
lgraph = addLayers(lgraph,Layers);  
dlnet = dlnetwork(lgraph);
plot(lgraph);
%% Training loop
numEpochs = 200;
miniBatchSize = 10;
initialLearnRate = 0.01;
decay = 0.00001;
gradDecay = 0.9;
sqGradDecay = 0.999;
mbq = minibatchqueue(trainCombinedDatastore,...
    'MiniBatchSize', miniBatchSize,...
    'MiniBatchFormat', {'SSCB', 'CB'}, ...
    'OutputAsDlarray', [1, 1],...
    'OutputEnvironment', 'auto');
averageGrad = [];
averageSqGrad = [];
numObservationsTrain = height(imageDataTrain);   % number of training observations
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
plots = 'training-progress';
if strcmp(plots, 'training-progress')
    figure
    lineLossTrain = animatedline;
    xlabel("Total Iterations")
    ylabel("Loss")
end
epoch = 0;
iteration = 0;
start = tic;
% Loop over epochs.
while epoch < numEpochs
    epoch = epoch + 1;
    % Shuffle data.
    shuffle(mbq);
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        % Read mini-batch of data.
        [dlX, dlT] = next(mbq);
        [loss, gradients, state] = dlfeval(@modelLoss,dlnet,dlX);
        dlnet.State = state;
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        % Update network parameters
        [dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,gradients,averageGrad,averageSqGrad,...
            iteration, learnRate, gradDecay, sqGradDecay);
        % Extract weights of first convolution layer
        conv1Weights = dlnet.Layers(2).Weights;
        % Print or save the weights
        disp('Weights of conv_1 layer:');
        disp(conv1Weights);
        if strcmp(plots, 'training-progress')
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain, iteration, double(gather(extractdata(loss))));
            title("Epoch: " + epoch + " , Elapsed: " + string(D));
            drawnow
        end
    end
end
%% Custom loss function
function [loss, gradients,state] = modelLoss(dlnet, dlX)
% Forward data through network.
[dlY, state] = forward(dlnet, dlX);
% Calculate parameter maps
fMap = dlY(:,:,1,:).*0.5;
DfMap = dlY(:,:,2,:).*0.107;
S0Map = (dlY(:,:,3,:).*0.6) + 0.7;
DsMap = dlY(:,:,4,:).*0.0017;
% diffusion weightings
dlB = [50 100 150 250 500 800];
% Use model outputs to predict the diffusion signal for each image
% in mini batch
Spred = zeros(size(dlX));
for b = 1:length(dlB)
    Spred(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
end
% Convert Spred to dlarray
Spred = dlarray(Spred, 'SSCB');
% Calculate the mse loss
loss = mse(Spred, dlX);
% Calculate gradients of loss with respect to learnable parameters.
gradients = dlgradient(loss, dlnet.Learnables);
end
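For reference, one way to confirm that the gradients really are zero (a sketch, not part of the script above; it uses the gradients table returned by the dlfeval call in the training loop, and GradNorm is just a display name) is to print the per-parameter gradient norms right after that call:
% Sketch: print per-parameter gradient norms inside the training loop,
% immediately after the dlfeval call that returns the gradients table.
gradNorm = cellfun(@(g) norm(gather(extractdata(g(:)))), gradients.Value);
disp(table(gradients.Layer, gradients.Parameter, gradNorm, ...
    'VariableNames', {'Layer','Parameter','GradNorm'}))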
0 Comments
Accepted Answer
Richard on 26 Jun 2023
        These lines of code:
% Use model outputs to predict the diffusion signal for each image
% in mini batch
Spred = zeros(size(dlX));
for b = 1:length(dlB)
    Spred(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
end
% Convert Spred to dlarray
Spred = dlarray(Spred, 'SSCB');
are creating a variable, Spred, that does not contain a traced dependency on the output of the network. This means that your mse() call is in fact only tracing a dependency on the original input dlX, so the gradients of the loss with respect to the learnables are zero.
Try this instead to create an Spred that incorporates the dependency on the network output:
% Use model outputs to predict the diffusion signal for each image
% in mini batch
Spred = zeros(size(dlX), 'like', dlX);
for b = 1:length(dlB)
    Spred(:,:,b,:) = S0Map .* (fMap.*exp(-dlB(b).*DfMap) + (1 - fMap).*exp(-dlB(b).*DsMap));
end
The 'like' syntax for zeros() constructs a zeros dlarray that is tracing, like its input, so your indexing within the loop is then captured. In your original version, because Spred is created as a plain double array, the indexing that places values into Spred(:,:,b,:) casts the computed and traced right-hand side into a plain double value, which loses the trace information that dlgradient depends on.
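To see the mechanism in isolation, here is a minimal standalone sketch (toy scalar values; w, x, and bufferGrad are made-up names for this illustration, not part of the network in the question). The only difference between the two calls is how the buffer is created:
% Sketch: a plain double buffer drops the trace, zeros(...,'like',...) keeps it.
w = dlarray(3);    % stands in for a learnable parameter
x = dlarray(2);    % stands in for the input data
gLost = dlfeval(@(w,x) bufferGrad(w, x, zeros(1,1)),          w, x)  % 0 (trace lost)
gKept = dlfeval(@(w,x) bufferGrad(w, x, zeros(1,1,'like',x)), w, x)  % nonzero (trace kept)
function grad = bufferGrad(w, x, buf)
    buf(1) = w .* x;                   % indexed assignment into the buffer
    loss = sum((buf - x).^2, 'all');   % compare "prediction" to "data", like the mse above
    grad = dlgradient(loss, w);
end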
Incidentally, I think you can also remove the loop entirely by reshaping dlB along the third dimension and relying on implicit expansion, which should be faster:
dlB = reshape(dlB, 1,1,[]);
Spred = S0Map .* (fMap.*exp(-dlB.*DfMap) + (1 - fMap).*exp(-dlB.*DsMap));
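With the sizes in this thread, that reshape makes dlB 1-by-1-by-6, so implicit expansion broadcasts it against the 132-by-132-by-1-by-batch parameter maps, and Spred comes out as a traced, formatted dlarray with the same SSCB layout as dlX, so the explicit dlarray conversion before the mse call is no longer needed.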
2 Comments
Richard on 26 Jun 2023
Hi Marissa,
10 samples is quite a small mini-batch size, and I think this is causing the noise you are seeing in the gradients. When I increase the mini-batch size to 64 I see a much smoother training-loss curve.
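For reference, assuming the same minibatchqueue setup as in the question, that change is just the MiniBatchSize value:
miniBatchSize = 64;
mbq = minibatchqueue(trainCombinedDatastore,...
    'MiniBatchSize', miniBatchSize,...
    'MiniBatchFormat', {'SSCB', 'CB'}, ...
    'OutputAsDlarray', [1, 1],...
    'OutputEnvironment', 'auto');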

More Answers (1)
Aniketh on 25 Jun 2023
A very probable cause for this, and something I have experienced myself a few times, is initialization: check how your network's weights are initialized. If the weights are initialized too small, it can lead to vanishing gradients. Consider using a suitable initialization method, such as Xavier (Glorot) or He initialization, which helps to maintain a reasonable range for the weights.
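For example, in Deep Learning Toolbox the initializer can be set per layer: the default for convolution2dLayer is Glorot (Xavier), and He can be requested explicitly. A minimal sketch using one of the layers from the question (the rest of the layer array stays the same):
% Sketch: He weight initialization on a convolution layer (Glorot/Xavier is the default).
convolution2dLayer([3 3], 32, "Name","conv_2", "Padding","same", ...
    "WeightsInitializer","he", "BiasInitializer","zeros")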
Another thing to consider is your network architecture: evaluate its depth and complexity. Very deep networks are more susceptible to vanishing gradients. If your network is too deep, consider reducing the number of layers or introducing skip connections (e.g., residual connections) to facilitate gradient flow.
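If you want to try skip connections, here is a minimal sketch of a residual-style block built with additionLayer and connectLayers (the layer names conv_a, relu_a, etc. are hypothetical, not the network from the question):
% Sketch: a skip connection that adds a block's input to its output before the activation.
layers = [
    imageInputLayer([132 132 6],"Name","in","Normalization","none")
    convolution2dLayer(1,32,"Name","conv_a","Padding","same")
    leakyReluLayer("Name","relu_a")
    convolution2dLayer(3,32,"Name","conv_b","Padding","same")
    additionLayer(2,"Name","add")        % second input comes from the skip path
    leakyReluLayer("Name","relu_b")];
lgraph = layerGraph(layers);
% Route the output of relu_a around conv_b and into the addition layer.
lgraph = connectLayers(lgraph,"relu_a","add/in2");
dlnet = dlnetwork(lgraph);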
1 Comment