Why Are Hidden State and Cell State Vectors Zero After Training an LSTM Model with trainNetwork Functionality?
Shubham Baisthakur
on 10 Oct 2023
Commented: Shubham Baisthakur
on 20 Oct 2023
I am training an LSTM model using the trainNetwork function, and the following is the architecture of my model:
layers = [ ...
sequenceInputLayer(size(X_train{1},1))
layerNormalizationLayer
lstmLayer(x.num_hidden_units,'OutputMode','sequence')
fullyConnectedLayer(x.num_layers_ffnn)
dropoutLayer(0.1)
fullyConnectedLayer(1)
regressionLayer];
And I am training this using the following command:
options = trainingOptions('adam', ...
'MaxEpochs', 75, ...
'MiniBatchSize', x.batch_size, ...
'SequenceLength', 'longest', ...
'Shuffle', 'once', ...
'L2Regularization',0.01,...
'ValidationData',{X_val,Y_val}, ...
'ValidationFrequency',10,...
'Verbose',false,...
'ExecutionEnvironment','multi-gpu');
% Train the LSTM network
net = trainNetwork(X_train, Y_train, layers, options);
After training the model, the HiddenState and CellState values of the LSTM layer are vectors of zeros. Why is this happening? I expected these vectors to have non-zero values, to ensure that the long-term dependency between the input and output parameters is captured.
Accepted Answer
Neha
on 20 Oct 2023
Hi Shubham,
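The LSTM (Long Short-Term Memory) layer is designed to carry information across arbitrary time intervals, which is what lets it learn long-term dependencies. However, what training learns are the layer's weights and bias (InputWeights, RecurrentWeights and Bias), not a fixed hidden or cell state. The HiddenState and CellState properties only hold the initial state the layer starts from at the beginning of a sequence, and this initial state defaults to zeros. The states become non-zero while a sequence is being processed and are computed again from the initial state for each new sequence. So zeros in these properties after training are standard behavior; they do not mean that the LSTM layer has not learned anything or that it is not working properly.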
If you want to keep the LSTM state between calls (for example, in time series forecasting, where you want the model to remember the state from the previous sequence), see the explanation of open-loop and closed-loop forecasting in the following documentation page:
That example uses the predictAndUpdateState function, which updates the network state at every time step.
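A minimal sketch of the difference between learned weights and the state properties, assuming a recent Deep Learning Toolbox release and hypothetical toy data (X, Y, the 16 hidden units and the 5 epochs are illustrative, not taken from the question). It compares the learned RecurrentWeights with the HiddenState property right after training, then runs one sequence through predictAndUpdateState and reads the updated state from the returned network. It assumes the updated state is exposed through the LSTM layer's HiddenState and CellState properties of the returned network. In the network from the question the LSTM is net.Layers(3), because layerNormalizationLayer comes before it.

```matlab
% Hypothetical toy data: 20 random sequences with 3 features and
% 50 time steps each, plus a per-time-step regression target.
rng(0);
numObs = 20;
X = cell(numObs,1);
Y = cell(numObs,1);
for i = 1:numObs
    X{i} = rand(3,50);
    Y{i} = cumsum(X{i}(1,:));   % target depends on all past inputs
end

layers = [ ...
    sequenceInputLayer(3)
    lstmLayer(16,'OutputMode','sequence')
    fullyConnectedLayer(1)
    regressionLayer];
options = trainingOptions('adam','MaxEpochs',5,'Verbose',false);
net = trainNetwork(X, Y, layers, options);

lstm = net.Layers(2);   % the LSTM layer in this toy network
% What training learned: the weights (non-zero)
fprintf('norm(RecurrentWeights) = %g\n', norm(lstm.RecurrentWeights));
% The state property only holds the initial state (zeros by default)
fprintf('max |HiddenState| after training = %g\n', max(abs(lstm.HiddenState)));

% Run one sequence through the network and keep the resulting state
[updatedNet, YPred] = predictAndUpdateState(net, X{1});
h = updatedNet.Layers(2).HiddenState;   % hidden state after the last time step
c = updatedNet.Layers(2).CellState;     % cell state after the last time step
fprintf('max |HiddenState| after predictAndUpdateState = %g\n', max(abs(h)));

% Return the state to its initial values before an unrelated sequence
netReset = resetState(updatedNet);
```

resetState puts the network state back to its initial values, so when forecasting with predictAndUpdateState it is worth calling it between independent sequences.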
Hope this helps!