Error in custom training loop: Value to differentiate must be a traced dlarray scalar.
Is it possible to include a Blackbox and still use Automatic Differentiation in MATLAB?
I am trying to do the following.
1) I have 3 input features, the x, y and z locations, computed using a custom function (getcondvects_n_k). There are M such examples, so xyz is a dlarray of size 3-by-M:
xyz=dlarray(flip(getcondvects_n_k([], 3, val_vectors),2),'BC');
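The reshape in step 4 only puts each network output back into the right grid cell if the order of the examples in xyz matches MATLAB's column-major order for an nx-by-ny-by-nz grid. getcondvects_n_k is not shown, so the sketch below uses a hypothetical ndgrid-based stand-in (the grid size and unit spacing are made up, and the flip is omitted) to build xyz and check that ordering:

```matlab
% Hypothetical stand-in for getcondvects_n_k: element centres on a regular grid.
nx = 4; ny = 3; nz = 2;                    % made-up grid size
[X, Y, Z] = ndgrid(1:nx, 1:ny, 1:nz);      % X(i,j,k) = i, Y(i,j,k) = j, Z(i,j,k) = k
pts = [X(:), Y(:), Z(:)];                  % M-by-3, rows in column-major grid order
M = size(pts, 1);                          % M = nx*ny*nz

% 'BC' marks rows as observations and columns as channels;
% dlarray stores the data in 'CB' order, so xyz is 3-by-M.
xyz = dlarray(pts, 'BC');
fmt = dims(xyz);                           % 'CB'
sz = size(xyz);                            % [3 M]

% A per-example quantity reshaped to the grid lands in the right cell:
% here the x coordinate of each example reproduces X exactly.
isOrdered = isequal(reshape(pts(:, 1), nx, ny, nz), X);
```

If getcondvects_n_k enumerates the elements in a different order, the same isequal check against the intended grid layout will show it.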
2) A neural network computes a value between 0 and 1 (the sigmoid output, ideally close to either 0 or 1) for each example:
layers = [
    featureInputLayer(3,"Name","elementCenterLocations")
    fullyConnectedLayer(20,"Name","fclayer1")
    batchNormalizationLayer("Name","batchnorm1")
    leakyReluLayer(0.3,"Name","leakyrelu1")
    fullyConnectedLayer(1,"Name","fclayer2")
    sigmoidLayer("Name","sigmoid")];
lgraph = layerGraph(layers);
dlnet=dlnetwork(lgraph);
3) Forward Pass
r=forward(dlnet,xyz);
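Because the network contains a batchNormalizationLayer, forward can also return the updated batch-normalization statistics as a second output. In a custom training loop these are usually written back to dlnet.State; otherwise predict will later use the initial statistics. A minimal sketch, with random inputs standing in for the real xyz (M = 24 is hypothetical):

```matlab
layers = [
    featureInputLayer(3, "Name", "elementCenterLocations")
    fullyConnectedLayer(20, "Name", "fclayer1")
    batchNormalizationLayer("Name", "batchnorm1")
    leakyReluLayer(0.3, "Name", "leakyrelu1")
    fullyConnectedLayer(1, "Name", "fclayer2")
    sigmoidLayer("Name", "sigmoid")];
dlnet = dlnetwork(layerGraph(layers));

M = 24;                                   % hypothetical number of examples
xyz = dlarray(rand(M, 3), 'BC');          % random stand-in for the real inputs

% Training-mode forward pass: batch normalization uses the statistics of this
% batch and returns updated running statistics in state.
[r, state] = forward(dlnet, xyz);
dlnet.State = state;

rVals = extractdata(r);                   % 1-by-M, each value between 0 and 1
```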
4) Blackbox
The output of the network is fed to a separate function. It acts like a custom loss function: it computes the loss and its derivative with respect to r, i.e. dl_dr, which is an nx-by-ny-by-nz array (nx*ny*nz = M).
R=reshape(double(extractdata(r)),nx, ny,nz);
[loss, dl_dr]=black_box(R, other_inputs);
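Since dl_dr is derived by hand, it is worth checking it numerically before feeding it back to the network, because any error in it goes straight into the weight updates. The sketch below uses a hypothetical quadratic loss (the real black_box and other_inputs are not shown) and compares the returned derivative with central finite differences:

```matlab
nx = 4; ny = 3; nz = 2;                    % hypothetical grid size
rng(0);
target = rand(nx, ny, nz);                 % hypothetical data inside the black box

% Hypothetical stand-in for black_box: returns the loss and dLoss/dR.
blackBox = @(R) deal(sum((R - target).^2, 'all'), 2*(R - target));

R = rand(nx, ny, nz);                      % plays the role of the reshaped network output
[loss, dl_dr] = blackBox(R);

% Central finite differences, perturbing one element of R at a time.
h = 1e-6;
dl_dr_fd = zeros(size(R));
for i = 1:numel(R)
    E = zeros(size(R));
    E(i) = h;
    [lossPlus, ~]  = blackBox(R + E);
    [lossMinus, ~] = blackBox(R - E);
    dl_dr_fd(i) = (lossPlus - lossMinus) / (2*h);
end

maxErr = max(abs(dl_dr_fd(:) - dl_dr(:)));   % should be close to zero
```

The same loop works with the real black_box by replacing blackBox(R) with black_box(R, other_inputs); with matrix inversions inside, a somewhat looser tolerance may be needed.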
5) Backward Pass
I want to use dl_dr to update the weights of the network:
grad = dlgradient(dlarray(dl_dr(:)),dlnet.Learnables,'RetainData',true);
[dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,loop,learnRate);
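The error comes from two things in the call above: dlgradient only works inside a function evaluated with dlfeval (that is what makes operations on dlarrays traced), and dlarray(dl_dr(:)) is a new, untraced vector rather than a scalar computed from the network output. One common workaround, sketched here under the assumption that dl_dr is the correct derivative of the loss with respect to r, is to differentiate the scalar sum(r .* dl_dr) with dl_dr held constant; by the chain rule its gradient equals the gradient of the loss. The toy model r = sigmoid(w*x) and the quadratic loss below are hypothetical, chosen so the result can be compared with differentiating the loss directly:

```matlab
rng(0);
x = dlarray(linspace(-1, 1, 6));   % hypothetical inputs
w = dlarray(0.7);                  % hypothetical single learnable parameter
t = rand(1, 6);                    % hypothetical targets used by the "black box"

% dlgradient must run inside dlfeval so that operations on dlarrays are traced.
gSurrogate = dlfeval(@surrogateGradient, w, x, t);
gDirect    = dlfeval(@directGradient, w, x, t);
gapBetween = abs(extractdata(gSurrogate) - extractdata(gDirect));

function g = surrogateGradient(w, x, t)
    r = sigmoid(w .* x);                       % traced forward pass
    R = extractdata(r);                        % plain numeric copy, outside the trace
    dl_dr = 2*(R - t);                         % "black box" derivative, computed numerically
    surrogate = sum(r .* dl_dr, 'all');        % traced scalar; dl_dr acts as a constant
    g = dlgradient(surrogate, w);
end

function g = directGradient(w, x, t)
    r = sigmoid(w .* x);
    loss = sum((r - t).^2, 'all');             % same loss, written with dlarray operations
    g = dlgradient(loss, w);
end
```

The value of the surrogate is not the loss, only its gradient matches, so the loss itself should still be taken from the black box for monitoring.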
6) The Forward Pass, Blackbox and Backward Pass will be in a custom training loop.
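Put together, a custom training loop for this setup could look roughly like the sketch below. Everything not in the question is an assumption: the random inputs, the grid size, the hypothetical toyBlackBox (a stand-in for black_box and other_inputs), the learning rate and the iteration count. The key point is that the forward pass, the call to the black box and dlgradient all happen inside modelGradients, which is evaluated with dlfeval:

```matlab
rng(0);
nx = 4; ny = 3; nz = 2; M = nx*ny*nz;      % hypothetical grid size
xyz = dlarray(rand(M, 3), 'BC');           % random stand-in for the real inputs
target = rand(nx, ny, nz);                 % hypothetical data used by toyBlackBox

layers = [
    featureInputLayer(3, "Name", "elementCenterLocations")
    fullyConnectedLayer(20, "Name", "fclayer1")
    batchNormalizationLayer("Name", "batchnorm1")
    leakyReluLayer(0.3, "Name", "leakyrelu1")
    fullyConnectedLayer(1, "Name", "fclayer2")
    sigmoidLayer("Name", "sigmoid")];
dlnet = dlnetwork(layerGraph(layers));

numIterations = 100;
learnRate = 0.01;
averageGrad = [];
averageSqGrad = [];
lossHistory = zeros(numIterations, 1);

for loop = 1:numIterations
    [grad, loss, state] = dlfeval(@modelGradients, dlnet, xyz, [nx ny nz], target);
    dlnet.State = state;                   % keep batch-norm statistics up to date
    [dlnet, averageGrad, averageSqGrad] = adamupdate(dlnet, grad, ...
        averageGrad, averageSqGrad, loop, learnRate);
    lossHistory(loop) = loss;
end

function [grad, loss, state] = modelGradients(dlnet, xyz, gridSize, target)
    [r, state] = forward(dlnet, xyz);                  % traced; r is 1-by-M
    R = reshape(double(extractdata(r)), gridSize);     % numeric, outside the trace
    [loss, dl_dr] = toyBlackBox(R, target);            % any non-dlarray code can run here
    % Give dl_dr the same layout and class as r; it is an untraced constant.
    c = dlarray(cast(reshape(dl_dr, 1, []), 'like', extractdata(r)), 'CB');
    surrogate = sum(r .* c, 'all');                    % traced scalar
    grad = dlgradient(surrogate, dlnet.Learnables);    % matches dLoss/dLearnables
end

function [loss, dl_dr] = toyBlackBox(R, target)
    % Hypothetical stand-in for black_box: quadratic loss and its derivative.
    loss = sum((R - target).^2, 'all');
    dl_dr = 2*(R - target);
end
```

Doing the black-box call inside modelGradients keeps a single forward pass per iteration, so r, dl_dr and the batch-normalization state all come from the same evaluation.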
I get this error when dlgradient is called. Can you suggest what to change? There are no known target outputs Y, and the Blackbox has many steps that involve matrix inversion. The inputs to the Blackbox cannot be dlarrays.
However, the equation relating the loss to r is straightforward, and hence so is its derivative.
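Because dl_dr is available, the chain rule lets the gradient of the loss with respect to the network parameters be written as the gradient of a scalar that automatic differentiation can trace. With θ the learnables, r_i(θ) the network output for example i and c_i = ∂L/∂r_i the corresponding entry of dl_dr, held fixed once the Blackbox has been evaluated:

```latex
\frac{\partial L}{\partial \theta}
  = \sum_{i=1}^{M} \frac{\partial L}{\partial r_i}\,\frac{\partial r_i}{\partial \theta}
  = \sum_{i=1}^{M} c_i\,\frac{\partial r_i}{\partial \theta}
  = \frac{\partial}{\partial \theta}\left(\sum_{i=1}^{M} c_i\, r_i(\theta)\right)\Bigg|_{c\ \text{fixed}},
\qquad c_i = \frac{\partial L}{\partial r_i}.
```

So a scalar such as sum(r .* dl_dr), computed from the traced r with dl_dr treated as a constant, has the same gradient with respect to the learnables as the loss, even though its value differs from the loss.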
Accepted Answer