predict function not working in custom training loop

Isabella on 25 Oct 2022
Commented: Isabella on 25 Oct 2022
I am building a custom training loop for a simple LSTM classification network because I need a custom loss function (specifically, 0-1 loss). I have followed a tutorial, but when I call the predict function within my custom loss function, I get the error: "Undefined function 'predict' for input arguments of type 'nnet.cnn.layer.Layer'."
I can successfully train the network with trainNetwork, so I am wondering what trainNetwork is doing that I am missing. If I call predict after training with trainNetwork, it works, but not in the training loop.
My input is a 50x1 sequence that is classified as either 1 or 2, depending on whether its average is positive or negative.
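A minimal sketch of toy data in this shape, laid out the way a dlnetwork expects it. Assumptions not stated in the post: `makeToyData` is a hypothetical helper, the values are standard-normal noise, and a positive average maps to class 1 (the post does not say which sign maps to which class). The sequences are stored as a dlarray with format "CBT" (channel x batch x time) and the labels as one-hot columns with format "CB".

```matlab
function [X, T] = makeToyData(numObservations)
% makeToyData  Hypothetical helper: random 1-channel sequences of 50 steps.
% Assumption: a positive sequence mean is class 1, otherwise class 2.
numTimeSteps = 50;

% Raw data laid out as channel x batch x time (1 x numObservations x 50).
Xraw = randn(1, numObservations, numTimeSteps);

% Average of each sequence over time, giving a 1 x numObservations row.
avg = mean(Xraw, 3);
isPositive = avg > 0;

% One-hot targets: row 1 = class 1 (positive mean), row 2 = class 2.
T = dlarray(double([isPositive; ~isPositive]), "CB");
X = dlarray(Xraw, "CBT");
end
```

For example, `[X,T] = makeToyData(32);` would produce one mini-batch of 32 labelled sequences.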
My network is defined as follows:
numFeatures = 1; %input data value (50 time points in sequence)
numHiddenUnits = 100;
numClasses = 2; %Left/right decision at the end
layers = [ ...
sequenceInputLayer(numFeatures)
lstmLayer(numHiddenUnits,"OutputMode","last")
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];
net = layers;
and my custom loss function is:
function [gradients,state,loss] = customGradients2Lay(net,dlX,Ylabel)
[Y,state]=predict(net,dlX);
loss=loss01(Y,Ylabel);
gradients=dlgradient(loss,net.Learnables);
end
function loss = loss01(Y, T)
if isequal(Y,T)
loss = 0;
else
loss = 1;
end
end
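For reference, a minimal sketch of how the loss function above could look once `net` is a dlnetwork. It assumes the network ends in softmaxLayer without classificationLayer and that the targets are one-hot columns; neither is stated in the post. Two points beyond the predict error matter here. First, dlgradient only works inside a function evaluated with dlfeval, on a loss that is a traced dlarray; the `isequal` result in `loss01` is a plain double, so dlgradient cannot differentiate it. Second, a 0-1 loss is piecewise constant, so even a traced version would have zero gradient almost everywhere. The sketch therefore swaps in cross-entropy as the training loss and reports the 0-1 error only as a metric. `modelLoss` and `err01` are hypothetical names, and it uses forward (the training-mode pass, which also returns the updated layer state) rather than predict.

```matlab
function [loss, gradients, state, err01] = modelLoss(net, X, T)
% modelLoss  Hypothetical loss function for a dlnetwork with no output layer.
% Must be evaluated through dlfeval so that dlgradient can trace the loss.
%   net : dlnetwork ending in softmaxLayer
%   X   : dlarray input, format "CBT" (1 channel x batch x 50 time steps)
%   T   : dlarray one-hot targets, format "CB" (2 classes x batch)

% Training-mode forward pass; state holds the updated layer states.
[Y, state] = forward(net, X);

% Differentiable surrogate: cross-entropy between softmax scores and targets.
loss = crossentropy(Y, T);

% Gradients of the surrogate loss with respect to the learnable parameters.
gradients = dlgradient(loss, net.Learnables);

% 0-1 error as a metric only: max along dimension 1 picks the class index.
[~, predictedClass] = max(extractdata(Y), [], 1);
[~, trueClass] = max(extractdata(T), [], 1);
err01 = mean(predictedClass ~= trueClass);
end
```

Inside a training loop it would be called as `[loss,gradients,state,err01] = dlfeval(@modelLoss,net,X,T);`, followed by `net.State = state;` and an optimizer step such as `[net,avgGrad,avgSqGrad] = adamupdate(net,gradients,avgGrad,avgSqGrad,iteration);`, with `avgGrad` and `avgSqGrad` initialized to `[]`.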
My training loop just uses a random dataset for testing and passes it to predict. Again, the error is: "Undefined function 'predict' for input arguments of type 'nnet.cnn.layer.Layer'."
I am also not sure why it is calling nnet.cnn. I even built a custom fully connected layer to try to get around this, and it was still calling the nnet.cnn class.
What am I missing?

Accepted Answer

James Gross on 25 Oct 2022
Hello,
To train your network in a custom training loop, you must specify your network as a dlnetwork.
net = dlnetwork(layers);
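A minimal sketch of this conversion applied to the layers from the question. One detail worth adding: dlnetwork does not accept output layers, so calling it on the original array fails on classificationLayer, which has to be dropped, with the loss moving into the custom loss function. This also explains the original error: `net = layers` leaves `net` as an array of layer objects of class nnet.cnn.layer.Layer (nnet.cnn.layer is the namespace MATLAB uses for its built-in layers in general, including lstmLayer, not only CNN layers), and that class has no predict method. The random batch of 8 sequences below is only an illustration.

```matlab
numFeatures = 1;      % one value per time step
numHiddenUnits = 100;
numClasses = 2;       % left/right decision at the end

% Same architecture as in the question, minus classificationLayer:
% dlnetwork does not accept output layers.
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,"OutputMode","last")
    fullyConnectedLayer(numClasses)
    softmaxLayer];

% net = layers would leave net as an nnet.cnn.layer.Layer array,
% which has no predict method; dlnetwork builds a trainable network.
net = dlnetwork(layers);
disp(class(net))   % prints dlnetwork

% predict on a dlnetwork expects dlarray input; "CBT" = channel x batch x time.
dlX = dlarray(randn(numFeatures, 8, 50), "CBT");   % 8 random sequences, 50 steps
Y = predict(net, dlX);   % 2 x 8 softmax scores, format "CB"
```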
You should then be able to train and call predict on your network as desired. For examples of how to train using a custom training loop with a dlnetwork, you can refer to one of the following:
I hope this information helps!
