What does the predict() function in Deep Learning Toolbox do?

Song Decn on 8 May 2021
Edited: Vidip on 21 Feb 2024
Hi, I followed the example of this
and made a small modification: instead of using the predict() function, I call predictAndUpdateState() to predict the target one time step at a time.
This gives a much worse prediction (brown line) than predict() (yellow line).
Can anyone explain this?
The only part that differs is:
% opt1: use only the feature variables as input
net = resetState(net);
YPred = [];
for i = 1:numel(XTest)
[net, temp] = predictAndUpdateState(net, XTest(:,i), 'ExecutionEnvironment', 'cpu');
YPred(:,i) = cell2mat(temp);
end
y1 = YPred;
Full code:
[~,~,data] = xlsread('ET_1.xlsx');
data_mat = cell2mat(data);
XTrain = (data_mat(:,4:8))';
XTrain = num2cell(XTrain,1);
YTrain = (data_mat(:,3))';
YTrain = num2cell(YTrain,1);
%%Define Network Architecture
featureDimension = size(XTrain{1},1);
numResponses = size(YTrain{1},1);
numHiddenUnits = 500;
layers = [ ...
sequenceInputLayer(featureDimension)
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(500) %%50
dropoutLayer(0.1) %%0.5
fullyConnectedLayer(numResponses)
regressionLayer
];
maxepochs = 500;
miniBatchSize = 1;
options = trainingOptions('adam', ... %%adam
'MaxEpochs',maxepochs, ...
'GradientThreshold',1, ...
'InitialLearnRate',0.005, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',125, ...
'LearnRateDropFactor',0.2, ...
'Verbose',0, ...
'Plots','training-progress');
%%Train the Network
net = trainNetwork(XTrain,YTrain,layers,options);
%% Test the Network
[~,~,data] = xlsread('ET_2.xlsx');
data_mat = cell2mat(data);
XTest = (data_mat(:,4:8))'; XTest = num2cell(XTest,1);
YTest = (data_mat(:,3))'; YTest = num2cell(YTest,1);
% opt1: use only the feature variables as input
net = resetState(net);
YPred = [];
for i = 1:numel(XTest)
[net, temp] = predictAndUpdateState(net, XTest(:,i), 'ExecutionEnvironment', 'cpu');
YPred(:,i) = cell2mat(temp);
end
y1 = YPred;
% opt2: predict()
net = resetState(net);
YPred = predict(net, XTest);
y2 = cell2mat(YPred); % no transpose needed: plot() accepts row or column vectors
%%
figure; hold all
yRef = (cell2mat(YTest)');
plot(yRef, '-o')
plot(y1, '-x')
plot(y2, '-s')
1 Comment
Song Decn on 10 May 2021
% Opt1:
% yTrain = predict(net, xTrainStandardized);
% yTrain = cell2mat(yTrain);
% Opt2:
% yTrain = [];
% for i = 1:numel(xTrainStandardized)
% [net, tmp] = predictAndUpdateState(net, xTrainStandardized(i));
% yTrain(i) = cell2mat(tmp);
% end
% Opt3:
[net, tmp] = predictAndUpdateState(net, xTrainStandardized);
yTrain = cell2mat(tmp);
Why do these three ways of calculating the responses give different values?


Answers (1)

Vidip on 21 Feb 2024
Edited: Vidip on 21 Feb 2024
The reason you get worse results with predictAndUpdateState in a loop than with predict is how the LSTM network's state is managed between predictions. When you pass a cell array of sequences to predict, it processes each sequence independently: every sequence starts from the same initial state, and no state is carried over from one sequence to the next. That is appropriate when your test sequences are not temporally related. When you call predictAndUpdateState in a loop without resetting the state after each prediction, however, the LSTM network's internal state carries over from one prediction to the next.
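The effect can be reproduced without Deep Learning Toolbox. The sketch below uses a toy scalar recurrence h_t = tanh(w*x_t + u*h_{t-1}) as a hypothetical stand-in for the LSTM layer (it is not the LSTM equations the toolbox uses), with made-up weights w, u and input x. It compares (a) treating every time step as its own length-1 sequence, which is what predict sees with the question's cell array, (b) a loop that carries the state forward, like predictAndUpdateState without resetState, and (c) the same loop with the state reset before each step.

```matlab
% Toy recurrent cell (hypothetical stand-in for an LSTM layer):
%   h_t = tanh(w*x_t + u*h_{t-1})
w = 0.8;                 % hypothetical input weight
u = 0.5;                 % hypothetical recurrent weight
x = [1.0 -0.5 2.0 0.3];  % hypothetical input, one value per time step
n = numel(x);

% (a) predict()-like: each time step is a separate sequence starting from h = 0
yIndependent = zeros(1, n);
for t = 1:n
    h0 = 0;                               % same initial state for every sequence
    yIndependent(t) = tanh(w*x(t) + u*h0);
end

% (b) predictAndUpdateState()-like loop: state from step t-1 feeds step t
yCarried = zeros(1, n);
h = 0;                                    % state right after resetState
for t = 1:n
    h = tanh(w*x(t) + u*h);               % carried-over state changes the output
    yCarried(t) = h;
end

% (c) same loop, but resetting the state before every step reproduces (a)
yReset = zeros(1, n);
for t = 1:n
    h = 0;                                % analogous to calling resetState each iteration
    h = tanh(w*x(t) + u*h);
    yReset(t) = h;
end

disp([yIndependent; yCarried; yReset])
```

Only the first step agrees between (a) and (b), because both start from the zero state; from the second step on, the carried state shifts every output of (b).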
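With the state carried over, each prediction in the loop is influenced by all the previous data points, which is not suitable if the sequences in XTest are supposed to be independent. In your code they are: num2cell(XTrain,1) and num2cell(XTest,1) split each data matrix into one cell per time step, so the network was trained on (and predict is given) independent sequences of length 1, and it never learned to use state carried over from earlier time steps. Feeding it state accumulated from these unrelated sequences leads to inaccurate predictions, because the network is using historical context it was never trained with.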
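The split into independent sequences comes from num2cell(...,1) in the question's code. The sketch below uses a small hypothetical matrix X (5 features by 4 time steps, standing in for data_mat(:,4:8)') to show the two ways of wrapping the same data: one cell per time step, as in the question, or a single cell holding the whole series.

```matlab
% Hypothetical data: 5 features x 4 time steps (stands in for data_mat(:,4:8)')
X = reshape(1:20, 5, 4);

% As in the question: one cell per column, i.e. 4 independent sequences of length 1
perStep = num2cell(X, 1);

% Alternative: one cell holding the whole series, i.e. 1 sequence of length 4
oneSeq = {X};

fprintf('perStep: %d sequences, each %dx%d\n', numel(perStep), size(perStep{1}, 1), size(perStep{1}, 2));
fprintf('oneSeq:  %d sequence, %dx%d\n', numel(oneSeq), size(oneSeq{1}, 1), size(oneSeq{1}, 2));
```

If temporal context is actually wanted, one option (an untested assumption for this data set) is to train and test on the single-sequence form, e.g. XTrain = {data_mat(:,4:8)'} and YTrain = {data_mat(:,3)'}. With one sequence and 'OutputMode','sequence', predict on the whole sequence and a predictAndUpdateState loop over its columns, both started after resetState, would be expected to agree up to floating-point differences.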
For further information, refer to the documentation link below:
