Accepted answer: Your mistake is the call dlfeval(@dlgradient,...). Put the code that computes the loss and the loss gradients into a function of its own, and then pass that function to dlfeval.
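A minimal sketch of that pattern follows. It assumes Deep Learning Toolbox (a GPU is used only if canUseGPU reports one). The toy two-layer generator, the random 32-by-32 images, the hyperparameter values and the local function name modelGLoss are hypothetical stand-ins for the question's netG and data. The sketch also keeps the Adam state as separate averageGrad and averageSqGrad variables plus an iteration counter, which is what adamupdate expects, instead of the single stateG argument used in the question, and it uses forward rather than predict, the usual choice inside a custom training loop.

```matlab
% Minimal sketch: loss and gradients are computed inside a function that is
% evaluated with dlfeval. The toy network and random data are hypothetical
% stand-ins for the question's netG, wrappedImages and unwrapImages.
layers = [
    imageInputLayer([32 32 1], 'Normalization', 'none')
    convolution2dLayer(3, 8, 'Padding', 'same')
    reluLayer
    convolution2dLayer(3, 1, 'Padding', 'same')];
netG = dlnetwork(layers);

% Build gpuArray-backed dlarrays when a GPU is available, CPU dlarrays otherwise.
wrappedImage = rand(32, 32, 1, 4, 'single');
realImage = rand(32, 32, 1, 4, 'single');
if canUseGPU
    wrappedImage = gpuArray(wrappedImage);
    realImage = gpuArray(realImage);
end
wrappedImage = dlarray(wrappedImage, 'SSCB');
realImage = dlarray(realImage, 'SSCB');

% Adam state expected by adamupdate: empty buffers before the first update.
averageGrad = [];
averageSqGrad = [];
learningRate = 2e-4;   % illustrative value
beta1 = 0.5;           % gradient decay factor (illustrative)
beta2 = 0.999;         % squared gradient decay factor (illustrative)

for iteration = 1:3
    % dlfeval traces modelGLoss, so dlgradient can be called inside it.
    [lossG, gradG] = dlfeval(@modelGLoss, netG, wrappedImage, realImage);
    [netG, averageGrad, averageSqGrad] = adamupdate(netG, gradG, ...
        averageGrad, averageSqGrad, iteration, learningRate, beta1, beta2);
    fprintf('Iteration %d, generator loss: %.4f\n', iteration, ...
        double(gather(extractdata(lossG))));
end

function [lossG, gradG] = modelGLoss(netG, wrappedImage, realImage)
    % forward (rather than predict) evaluates the network in training mode.
    fakeImage = forward(netG, wrappedImage);
    lossG = mean((fakeImage - realImage).^2, 'all');
    % One "wrt" input (the learnables table), so exactly one output.
    gradG = dlgradient(lossG, netG.Learnables);
end
```

If netG contains batch normalization layers, return the updated state as well, i.e. [fakeImage, state] = forward(netG, wrappedImage), pass it back out of the loss function, and assign it to netG.State after the dlfeval call. The discriminator step can follow the same pattern with its own loss function.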
I can't solve this problem. I always get the error Output argument "varargout{2}" (and possibly others) not assigned a value in the execution with "dlarray/dlgradient" function.
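That message is consistent with requesting more outputs from dlgradient than there are "wrt" inputs. dlgradient returns one gradient per input listed after the loss and does not return the loss itself, so a call such as [gradG, lossG] = dlgradient(lossG, netG.Learnables) asks for a second output that is never assigned. A small sketch, assuming Deep Learning Toolbox; the function name twoInputGradients and the sample values are hypothetical:

```matlab
% Minimal sketch: dlgradient returns one gradient per "wrt" input, so the
% number of requested outputs must match the number of inputs after the loss.
x0 = dlarray([1 2 3]);
y0 = dlarray([4 5 6]);

% dlgradient is called inside a function evaluated by dlfeval so that the
% operations on x and y are traced.
[gx, gy] = dlfeval(@twoInputGradients, x0, y0);
disp(extractdata(gx))   % d/dx sum(x.^2) = 2*x, i.e. [2 4 6]
disp(extractdata(gy))   % d/dy sum(y.^3) = 3*y.^2, i.e. [48 75 108]

function [gx, gy] = twoInputGradients(x, y)
    loss = sum(x.^2) + sum(y.^3);
    % Two "wrt" inputs, so two outputs may be requested.
    [gx, gy] = dlgradient(loss, x, y);
    % By contrast, [g, extra] = dlgradient(loss, x) asks for a second
    % output that dlgradient never assigns.
end
```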
function [netG, stateG, lossG] = modelGStep(netG, wrappedImage, realImage, stateG, learningRate, beta1, beta2)
    % ensure the inputs are gpuArray-backed dlarrays
    if ~isa(wrappedImage, 'dlarray')
        wrappedImage = dlarray(gpuArray(wrappedImage), 'SSCB');
    elseif ~strcmp(underlyingType(wrappedImage), 'gpuArray')
        wrappedImage = dlarray(gpuArray(extractdata(wrappedImage)), 'SSCB');
    end
    if ~isa(realImage, 'dlarray')
        realImage = dlarray(gpuArray(realImage), 'SSCB');
    elseif ~strcmp(underlyingType(realImage), 'gpuArray')
        realImage = dlarray(gpuArray(extractdata(realImage)), 'SSCB');
    end
    wrappedImage = dlarray(gpuArray(wrappedImage), 'SSCB');
    realImage = dlarray(gpuArray(realImage), 'SSCB');
    % ensure dlfeval uses dlgradient
    [gradG, lossG] = dlfeval(@dlgradient, lossG, netG.Learnables);
    fakeImage = predict(netG, wrappedImage);
    lossG = mean((fakeImage - realImage).^2, 'all');
    [gradG, lossG] = dlgradient(lossG, netG.Learnables);
    [netG, stateG] = adamupdate(netG, gradG, stateG, learningRate, beta1, beta2);
    return
end
This is my function. Below is my training loop:
for epoch = 1:epochs
    for i = 1:size(unwrapImages, 4)
        realImage = unwrapImages(:,:,:,i);
        wrappedImage = wrappedImages(:,:,:,i);
        [netG, stateG, lossG] = modelGStep(netG, wrappedImage, realImage, stateG, learningRate, beta1, beta2);
        [lossD, gradD] = modelDStep(netD, realImage, wrappedImage, netG);
        [netD, stateD] = adamupdate(netD, gradD, stateD, learningRate, beta1, beta2);
        gLosses(epoch) = gLosses(epoch) + double(gather(extractdata(lossG)));
        dLosses(epoch) = dLosses(epoch) + double(gather(extractdata(lossD)));
    end
    gLosses(epoch) = gLosses(epoch) / size(unwrapImages, 4);
    dLosses(epoch) = dLosses(epoch) / size(unwrapImages, 4);
    fprintf('Epoch %d, Generator Loss: %.4f, Discriminator Loss: %.4f\n', ...
        epoch, gLosses(epoch), dLosses(epoch));
end
What should I do to solve this? Thanks!