MATLAB Answers

Can I hold 2 batches of dlnetwork gradients and update network parameters in 1 operation?

Zh Y on 19 Mar 2021
Commented: Zh Y on 24 Mar 2021
Because of limited GPU memory, I can't train my deep learning network with a mini-batch of 16 samples.
Can I instead compute the gradients for a mini-batch of 8 samples, do this twice, and then update the network parameters using the gradients from both batches?
I compute the gradients of the network with
[gradients,state,loss] = dlfeval(@modelGradient,dlNet,xTrain,yTrain);
so after 2 batches I have gradients1, gradients2, state1, state2, loss1, and loss2.
My first thought is that the total gradient should be the mean of gradients1 and gradients2.
But how should I combine the state values? Is the combined state also the mean of state1 and state2? Thank you.

Accepted Answer

Joss Knight on 21 Mar 2021
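Yes, absolutely: accumulate the gradients over sub-batches until they cover the batch size you want, then update the model. The principle is exactly the same as training a model on multiple GPUs (or CPUs). If your loss is a mean over the samples, weight each sub-batch gradient by its share of the full batch before summing; for two equal sub-batches of 8 that is simply the mean of gradients1 and gradients2.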
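To see why combining sub-batch gradients this way works, here is a minimal sketch in plain MATLAB (no Deep Learning Toolbox needed). It assumes a hypothetical linear model with a mean-squared-error loss and random data; the names X, y, w and gradFcn are illustrative only. It compares the gradient of the full 16-sample batch with the weighted sum of the gradients of two 8-sample sub-batches.

```matlab
% Sketch: for a loss that is the MEAN over samples, the full-batch gradient
% equals the sum of sub-batch gradients weighted by n_i/N.
% Hypothetical linear model and data, for illustration only.
rng(0);                          % reproducible random data
N = 16;                          % desired mini-batch size
X = randn(N,3);                  % hypothetical inputs
y = randn(N,1);                  % hypothetical targets
w = randn(3,1);                  % hypothetical model weights

% Gradient of L = mean((X*w - y).^2) with respect to w
gradFcn = @(Xb,yb) (2/size(Xb,1)) * Xb' * (Xb*w - yb);

% Full 16-sample batch (the one that does not fit in GPU memory)
gFull = gradFcn(X,y);

% Two sub-batches of 8 samples, processed one after the other
idx1 = 1:8;  idx2 = 9:16;
g1 = gradFcn(X(idx1,:), y(idx1));
g2 = gradFcn(X(idx2,:), y(idx2));

% Weight each sub-batch gradient by its fraction of the samples
f1 = numel(idx1)/N;  f2 = numel(idx2)/N;
gAccum = f1*g1 + f2*g2;          % equals mean([g1 g2],2) for equal halves

% Maximum difference: only floating-point round-off
disp(max(abs(gAccum - gFull)))
```

With a dlnetwork, the gradient tables returned by dlfeval should be combinable the same way with dlupdate, for example gradients = dlupdate(@(a,b) 0.5*(a+b), gradients1, gradients2), before passing the result to adamupdate or sgdmupdate. If your loss were a sum over samples rather than a mean, you would add the gradients without the weights.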
How to combine the State depends on what the state contains. This example shows you how to aggregate batch normalization state; the function aggregateState below is the part of interest. Instead of using gplus, which sums values across parallel workers, you would aggregate over your sub-iterations.
function state = aggregateState(state,factor)
    % Combine batch normalization statistics across parallel workers.
    % factor is this worker's fraction of the total mini-batch.
    numrows = size(state,1);
    for j = 1:numrows-1    % rows j and j+1 are compared, so stop one row early
        % A batch norm layer stores TrainedMean followed by TrainedVariance
        isBatchNormalizationState = state.Parameter(j) == "TrainedMean" ...
            && state.Parameter(j+1) == "TrainedVariance" ...
            && state.Layer(j) == state.Layer(j+1);
        if isBatchNormalizationState
            meanVal = state.Value{j};
            varVal = state.Value{j+1};
            % Calculate combined mean (gplus sums across workers)
            combinedMean = gplus(factor*meanVal);
            % Calculate combined variance terms to sum
            combinedVarTerm = factor.*(varVal + (meanVal - combinedMean).^2);
            % Update state
            state.Value(j) = {combinedMean};
            state.Value(j+1) = {gplus(combinedVarTerm)};
        end
    end
end
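Without a parallel pool, the same formulas can be applied by looping over your sub-iterations instead of calling gplus. The sketch below is a minimal illustration, assuming plain arrays stand in for the TrainedMean and TrainedVariance values (the data and the names subIdx, means, vars and factors are hypothetical). It stores per-sub-batch statistics, combines them, and checks the result against the statistics of the whole 16-sample batch.

```matlab
% Sketch: combine batch-norm style statistics over sub-iterations (no gplus).
% Hypothetical activations: 16 samples x 4 channels.
rng(1);
X = 3*randn(16,4) + 2;
subIdx = {1:8, 9:16};            % two sub-batches of 8 samples
N = size(X,1);
K = numel(subIdx);

% Statistics each sub-iteration would produce
means = zeros(K,size(X,2));
vars = zeros(K,size(X,2));
factors = zeros(K,1);
for k = 1:K
    Xk = X(subIdx{k},:);
    means(k,:) = mean(Xk,1);
    vars(k,:) = var(Xk,1,1);     % population variance (normalized by sub-batch size)
    factors(k) = size(Xk,1)/N;   % share of the full batch
end

% Combined mean: weighted sum of sub-batch means (replaces gplus(factor*meanVal))
combinedMean = sum(factors .* means, 1);

% Combined variance: weighted sum of (variance + squared offset of each mean)
combinedVar = sum(factors .* (vars + (means - combinedMean).^2), 1);
```

This identity is exact for per-batch means and population variances. The TrainedMean and TrainedVariance values in a network's State are running estimates, so applying the same combination to them follows the approach of aggregateState rather than guaranteeing results identical to a true 16-sample batch.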

Release

R2020a
