Develop WGAN-GP for 3-D image

Shuaibin WAN on 23 Mar 2022
Commented: Shuaibin WAN on 3 Apr 2022
Hi, I am a beginner in using MATLAB to develop generative adversarial networks (GANs). Based on the MATLAB WGAN-GP tutorial, I have developed a WGAN-GP model for 3-D images (H x W x D x C). The modified 'modelGradientsD' function is shown below.
function [gradientsD, lossD, lossDUnregularized, D_X, D_G_Z1] = modelGradientsD(dlnetG, dlnetD, dlZ, dlX, lambda)
% Calculate the prediction for training images with D
dlYPred = forward(dlnetD, dlX);
% Calculate the prediction for G-generated images with D
dlXGenerated = forward(dlnetG, dlZ);
dlYPredGenerated = forward(dlnetD, dlXGenerated);
% Calculate D(X) and D(G(Z))
D_X = mean(dlYPred);
D_G_Z1 = mean(dlYPredGenerated);
% Calculate the unregularized loss
lossDUnregularized = D_G_Z1 - D_X;
% Get the interpolated image from the training and generated images
epsilon = rand([1 1 1 1 size(dlX,5)], 'like', dlX);
dlXInterpolated = epsilon.*dlX + (1-epsilon).*dlXGenerated;
dlYPredInterpolated = forward(dlnetD, dlXInterpolated);
% Calculate the gradient penalty
gradientsInterpolated = dlgradient(sum(dlYPredInterpolated), dlXInterpolated, 'EnableHigherDerivatives', true);
gradientsInterpolatedNorm = sqrt(sum(gradientsInterpolated.^2,1:4) + 1e-10);
gradientPenalty = lambda.*mean((gradientsInterpolatedNorm - 1).^2);
% Calculate the loss with gradient penalty
lossD = lossDUnregularized + gradientPenalty;
% Calculate the gradients of the loss with respect to learnable parameters
gradientsD = dlgradient(lossD, dlnetD.Learnables);
end
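For reference, the function above appears to implement the usual WGAN-GP discriminator (critic) loss. The notation below is an assumption for illustration and is not part of the original post: D is dlnetD, G is dlnetG, x is a training batch, z is the latent input, and the small 1e-10 stabilizer inside the square root is omitted.

```latex
\hat{x} = \epsilon\, x + (1-\epsilon)\, G(z), \qquad \epsilon \sim U[0,1]

L_D = \mathbb{E}_{z}\big[D(G(z))\big] - \mathbb{E}_{x}\big[D(x)\big]
      + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]
```

The first two terms correspond to lossDUnregularized, and the last term corresponds to gradientPenalty. Because the penalty already contains a gradient of D, differentiating L_D with respect to the learnable parameters requires a second-order derivative.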
When I run the program, however, the error shown below appears. Something seems to go wrong when calculating 'gradientsD'. After many debugging attempts, I found that removing 'EnableHigherDerivatives' from the calculation of 'gradientsInterpolated' makes the error go away. However, the WGAN-GP then performs poorly, and I have several questions: (1) Does removing 'EnableHigherDerivatives' affect the model training significantly? (2) Is there a robustness issue in the 'dlgradient' function? (3) Are there other solutions to this error?
I would really appreciate it if you could offer any ideas or suggestions. Thanks a lot!
Error using +
Arrays have incompatible sizes for this operation.
Error in gpuArray/internal_dlconv (line 57)
stride, dilation, numGroups) + bias;
Error in deep.internal.recording.operations.DlconvBackwardOp/backward (line 88)
ddZ2 = internal_dlconv(ddX,weights,zeroBias,op.Args{:});
Error in deep.internal.recording.RecordingArray/backwardPass (line 89)
grad = backwardTape(tm,{y},{initialAdjoint},x,retainData,false,0);
Error in dlarray/dlgradient (line 132)
[grad,isTracedGrad] = backwardPass(y,xc,pvpairs{:});
Error in WGANGP_V1>modelGradientsD (line 442)
gradientsD = dlgradient(lossD, dlnetD.Learnables);
Error in deep.internal.dlfeval (line 17)
[varargout{1:nargout}] = fun(x{:});
Error in dlfeval (line 40)
[varargout{1:nargout}] = deep.internal.dlfeval(fun,varargin{:});
Error in WGANGP_V1 (line 226)
[gradientsD, lossD, lossDUnregularized, D_X, D_G_Z1] = dlfeval(@modelGradientsD, dlnetG, dlnetD, dlZ, dlX, lambda);

Accepted Answer

Joss Knight on 24 Mar 2022
You definitely need to use 'EnableHigherDerivatives' here because you are including computed gradients in the loss term. Without it your training will not work correctly.
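As a minimal sketch of what the option does (this is not taken from the code in the question; the function name secondDerivative and the cubic test function are hypothetical), the first dlgradient call below is traced so that its result can be differentiated again. As far as I understand it, without the flag the first gradient is not traced, so a loss term built from it would be treated as a constant when computing gradientsD.

```matlab
% Evaluate at x = 2. For y = x^3, analytically dy/dx = 3*x^2 = 12
% and d2y/dx2 = 6*x = 12.
x = dlarray(2);
[dydx, d2ydx2] = dlfeval(@secondDerivative, x);
disp(extractdata(dydx));
disp(extractdata(d2ydx2));

function [dydx, d2ydx2] = secondDerivative(x)
    % Hypothetical helper: scalar function whose derivatives are known.
    y = x.^3;
    % Trace the backward pass so that dydx can itself be differentiated.
    dydx = dlgradient(y, x, 'EnableHigherDerivatives', true);
    % Second-order derivative, computed from the traced first derivative.
    d2ydx2 = dlgradient(dydx, x);
end
```

The gradient penalty in modelGradientsD follows the same pattern: gradientsInterpolated plays the role of dydx, and the final dlgradient call with respect to dlnetD.Learnables differentiates through it.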
It looks like this is a bug with higher order derivatives and 3-D data, which is fixed in R2022a. Can you get the latest version of MATLAB?
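If upgrading is not immediately possible, a small guard before the training loop can at least make the problem visible. This is only a sketch: the warning identifier and message are made up for illustration, and the check assumes the fix is present from R2022a onward as described above.

```matlab
% Hypothetical guard placed before the training loop.
% isMATLABReleaseOlderThan and matlabRelease are available from R2020b.
needsUpgrade = isMATLABReleaseOlderThan("R2022a");
if needsUpgrade
    warning("WGANGP:oldRelease", ...
        "Release %s may fail when differentiating the gradient penalty on 3-D data.", ...
        matlabRelease.Release);
end
```

The guard does not work around the bug; it only reports that the installed release predates the one in which the fix is said to be available.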
3 Comments
Joss Knight on 2 Apr 2022
I have requested that this be fixed in a future update of R2021b.
Shuaibin WAN on 3 Apr 2022
That sounds GREAT! Thank you so much, sir.

Version

R2021b
