Error in custom training loop: Value to differentiate must be a traced dlarray scalar.
Is it possible to include a black-box function and still use automatic differentiation in MATLAB?
I am trying to do the following:
1) I have 3 input features, the x, y and z locations, computed with a custom function (getcondvects_n_k) for M examples. xyz is a dlarray of size 3-by-M:
xyz=dlarray(flip(getcondvects_n_k(, 3, val_vectors),2),'BC');
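The 'BC' label in the line above tells dlarray that the unformatted input has observations (B) along dimension 1 and features (C) along dimension 2, i.e. it is M-by-3. Because dlarray reorders formatted dimensions into S, C, B, T, U order, the stored result is 3-by-M with format 'CB', which matches the shape described. A minimal sketch, using random numbers as a hypothetical stand-in for the getcondvects_n_k output (assumed here to be M-by-3, one row per example):

```matlab
% Hypothetical stand-in for getcondvects_n_k: an M-by-3 array, one row per example
M = 5;
raw = rand(M, 3);

% flip(...,2) reverses the column order of each row
xyzRows = flip(raw, 2);

% 'BC': dimension 1 is the batch (observations), dimension 2 the channels (features)
xyz = dlarray(xyzRows, 'BC');

% dlarray stores formatted data in S, C, B, T, U order,
% so xyz is 3-by-M with format 'CB'
fmt = dims(xyz);
sz = size(xyz);

% Equivalent construction: transpose to 3-by-M first and label it 'CB'
xyz2 = dlarray(xyzRows.', 'CB');
```

Either form gives the network the same channel-by-batch layout; labelling the dimensions as they appear in the input is what matters.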
2) A neural network (NN) computes a value of either 0 or 1 for each example:
layers = [
    % ... (layer definitions elided)
    ];
lgraph = layerGraph(layers);
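The layer definitions are not shown, so the architecture below is purely hypothetical. It illustrates two points relevant to a custom training loop: the network is wrapped in a dlnetwork (which must not end in an output/loss layer), and a sigmoidLayer yields a value strictly between 0 and 1 per example rather than exactly 0 or 1, since a hard 0/1 threshold has zero gradient almost everywhere and cannot be trained by gradient descent.

```matlab
% Hypothetical architecture; the actual layers are not given in the question
layers = [
    featureInputLayer(3, 'Name', 'xyz')     % x, y, z features
    fullyConnectedLayer(32, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(1, 'Name', 'fc2')
    sigmoidLayer('Name', 'sigmoid')         % one value in (0, 1) per example
    ];
lgraph = layerGraph(layers);

% Custom training loops need a dlnetwork (no output/loss layer)
dlnet = dlnetwork(lgraph);

% Forward pass on a 3-by-M formatted input
M = 5;
xyz = dlarray(rand(3, M), 'CB');
R = forward(dlnet, xyz);                    % 1-by-M dlarray with format 'CB'
Rvals = extractdata(R);
```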
3) Forward Pass
The NN maps xyz to its output R.
4) Black Box
R is fed to a separate function that acts like a custom loss function: it computes the loss and the derivative of the loss with respect to r (the NN output), dl_dr, which is an nx-by-ny-by-nz array.
[loss, dl_dr]=black_box(R, other_inputs);
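A sketch of the black-box step under stated assumptions: the black box below is a hypothetical quadratic loss on plain doubles, and M = nx*ny*nz with the NN outputs ordered so that reshape maps them onto the grid. The central finite-difference check is one way to confirm that a hand-derived dl_dr is correct before using it for training.

```matlab
% Hypothetical grid size; assumes M = nx*ny*nz and that a plain reshape
% maps the NN outputs onto the grid
nx = 2; ny = 3; nz = 4;
M = nx*ny*nz;
Rvec = rand(1, M);                  % stands in for extractdata(R)

r = reshape(Rvec, nx, ny, nz);      % plain double, nx-by-ny-by-nz
target = rand(nx, ny, nz);          % hypothetical data used inside the black box

% Hypothetical black box: works on plain doubles only and returns
% the loss together with its analytic derivative with respect to r
blackBoxLoss = @(r) 0.5*sum((r - target).^2, 'all');
loss  = blackBoxLoss(r);
dl_dr = r - target;

% Central finite-difference check of one entry of dl_dr
h = 1e-6;
e = zeros(size(r));
e(5) = h;
fd = (blackBoxLoss(r + e) - blackBoxLoss(r - e)) / (2*h);

% Back to the 1-by-M layout of the NN output
dl_dr_vec = reshape(dl_dr, 1, M);
```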
5) Backward Pass
I want to use dl_dr to update the weights of the NN:
grad = dlgradient(dlarray(dl_dr(:)),dlnet.Learnables,'RetainData',true);
[dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,loop,learnRate);
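dlgradient needs a scalar that was computed from the learnables inside the function passed to dlfeval; dlarray(dl_dr(:)) is neither scalar nor traced, which is what the error reports. One common workaround (a different route from the one suggested in the answer) uses the chain rule: hold dl_dr fixed as a constant c and differentiate the traced scalar sum(R .* c). Its gradient with respect to the learnables equals the gradient of the black-box loss, so the black box never has to see a dlarray. A minimal sketch with a hypothetical network and a hypothetical quadratic black box; the network runs once outside the trace to feed the black box plain doubles, and once inside dlfeval to build the surrogate.

```matlab
rng(0);                                    % reproducible initialization

% Hypothetical network (the real layers are not shown in the question)
layers = [
    featureInputLayer(3)
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(1)
    sigmoidLayer
    ];
dlnet = dlnetwork(layerGraph(layers));

M = 8;
xyz = dlarray(rand(3, M), 'CB');
target = rand(1, M);                       % hypothetical data used by the black box

% Forward pass outside the trace; the black box only sees plain numbers
r = extractdata(forward(dlnet, xyz));      % 1-by-M numeric

% Hypothetical black box: loss = 0.5*sum((r - target).^2) and dLoss/dr.
% (In the question dl_dr is nx-by-ny-by-nz; it would be reshaped to 1-by-M.)
loss  = 0.5*sum((r - target).^2, 'all');
dl_dr = r - target;

% Traced scalar surrogate s = sum(R .* c), with c = dl_dr held constant:
% ds/dtheta = sum_i dl_dr(i) * dR(i)/dtheta = dLoss/dtheta (chain rule)
surrogateGrad = @(net, X, c) dlgradient( ...
    sum(forward(net, X) .* dlarray(c, 'CB'), 'all'), net.Learnables);

grad = dlfeval(surrogateGrad, dlnet, xyz, dl_dr);
```

In a full script the same steps would usually live in a model-gradients function called through dlfeval: compute R = forward(net, X) once, pass extractdata(R) to black_box, and build the surrogate from that same traced R. The anonymous-function form above only keeps the sketch in a single file.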
6) The forward pass, black box and backward pass will all run inside a custom training loop.
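A minimal custom-training-loop sketch that combines the steps, assuming the surrogate-gradient workaround (hold dl_dr fixed and differentiate sum(R .* dl_dr)) and using a hypothetical network, hypothetical data and a hypothetical quadratic black box in place of the real ones:

```matlab
rng(0);

% Hypothetical network and data, placeholders for the real ones
layers = [
    featureInputLayer(3)
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(1)
    sigmoidLayer
    ];
dlnet = dlnetwork(layerGraph(layers));
M = 8;
xyz = dlarray(rand(3, M), 'CB');
target = rand(1, M);                       % hypothetical black-box data

% Gradient of the traced surrogate sum(R .* c), with c = dl_dr held constant
surrogateGrad = @(net, X, c) dlgradient( ...
    sum(forward(net, X) .* dlarray(c, 'CB'), 'all'), net.Learnables);

numIterations = 200;
learnRate = 1e-2;
averageGrad = [];
averageSqGrad = [];
lossHistory = zeros(numIterations, 1);

for iteration = 1:numIterations
    % Forward pass, then the hypothetical black box on plain doubles
    r = extractdata(forward(dlnet, xyz));
    lossHistory(iteration) = 0.5*sum((r - target).^2, 'all');
    dl_dr = r - target;

    % Backward pass through the network only
    grad = dlfeval(surrogateGrad, dlnet, xyz, dl_dr);

    % Adam update of the learnables
    [dlnet, averageGrad, averageSqGrad] = adamupdate(dlnet, grad, ...
        averageGrad, averageSqGrad, iteration, learnRate);
end
```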
I get the error when dlgradient is called. Can you please suggest any changes? There are no known outputs Y, and the black box has many steps that involve matrix inversion. The inputs to the black box cannot be dlarray objects.
However, the equation relating the loss and r is straightforward, and hence its derivative is also straightforward.
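Why a hand-computed dl_dr is sufficient can be written out with the chain rule. Writing θ for the network learnables, r_i(θ) for the M network outputs, and c_i for the entries of dl_dr evaluated at the current outputs and then held fixed:

```latex
\frac{\partial L}{\partial \theta}
  = \sum_{i=1}^{M} \frac{\partial L}{\partial r_i}\,\frac{\partial r_i}{\partial \theta}
  = \frac{\partial}{\partial \theta}\left( \sum_{i=1}^{M} c_i\, r_i(\theta) \right),
  \qquad c_i = \left.\frac{\partial L}{\partial r_i}\right|_{r = r(\theta)} \text{ (held fixed)}
```

So automatic differentiation only has to handle the scalar \sum_i c_i r_i(\theta), which depends on the network alone; the black box supplies the numbers c_i and is never differentiated itself.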
Mahesh Taparia on 29 Jul 2021
You are converting the dlarray data to double and then converting it back to dlarray to perform automatic differentiation.
To overcome this issue, do not convert the dlarray data to double with the extractdata function. Once the data is extracted, it no longer carries the trace that automatic differentiation needs, so gradients cannot flow back through it. Instead, try to find a way to pass the input of the black_box function directly as a dlarray, do the required computation on it, and return the output as a dlarray. This will solve the issue.
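A minimal sketch of this suggestion, assuming the black box can be rewritten using only operations that support dlarray. The quadratic loss is a hypothetical stand-in; whether the matrix inversions in the real black box are supported would need checking against the list of functions with dlarray support. The second call shows why extracting the data breaks differentiation: the rebuilt dlarray is not traced, so dlgradient raises an error.

```matlab
target = rand(2, 3);                            % hypothetical black-box data
lossFcn = @(r) 0.5*sum((r - target).^2, 'all'); % dlarray-compatible operations only

r = dlarray(rand(2, 3));

% Works: the loss is computed from the traced input r
g = dlfeval(@(r) dlgradient(lossFcn(r), r), r);

% Fails: extractdata drops the trace, so the loss no longer depends on r
failed = false;
try
    gBad = dlfeval(@(r) dlgradient(lossFcn(dlarray(extractdata(r))), r), r);
catch err
    failed = true;
    disp(err.message)
end
```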
Hope it will help!