Initial State of a Dynamical System in an LSTM Network

Michael Hesse on 18 Nov 2020
Commented: Michael Hesse on 19 Nov 2020
%% - cleanup
clear;
close all;
clc;
%% - data
t = linspace(0, 5, 1000);
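odefcn = @(t, x) [x(2, :); 10*sin(x(1, :))-x(2, :)];  % pendulum with damping, state = [angle; angular velocity]
x0 = [pi/2, 0]';
[~, x] = ode45(odefcn, t, x0);
x = x';               % 2-by-1000: one column per time step
X = x(:, 1:end-1);    % inputs:  states at steps 1 ... N-1
Y = x(:, 2:end);      % targets: states at steps 2 ... N (one step ahead)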
%% - define and train lstm network
numFeatures = 2;
numResponses = 2;
numHiddenUnits = 200;
layers = [sequenceInputLayer(numFeatures);
    lstmLayer(numHiddenUnits);
    fullyConnectedLayer(numResponses);
    regressionLayer];
opts = trainingOptions('adam', 'MaxEpochs', 100, 'Plots', 'training-progress');
net = trainNetwork(X, Y, layers, opts);
%% - prediction
net = resetState(net);
xpred = x0;
for i = 1 : length(t)-1
    % closed loop: feed each prediction back in as the next input
    [net, xpred(:, i+1)] = predictAndUpdateState(net, xpred(:, i));
end
%% - plotting
figure(1);
plot(t, x);
hold on;
grid on;
plot(t, xpred, '--');
This is example code in which I want to predict the trajectory of a pendulum with an LSTM neural network. How can I provide the initial state x0 to the network? In the figure, the second predicted state jumps directly from x0 to the state [0, 0]'. Why does this happen?
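The training data in the code above consists of one-step-ahead pairs. The following toolbox-free sketch checks that pairing. It is a stand-in under stated assumptions, not the poster's pipeline: a forward-Euler trajectory of the same ODE, with a hypothetical step size dt, replaces the ode45 solution, and no network is trained. It shows that the first training input is exactly x0 (the same vector the prediction loop feeds to the network first) and that every target is the next input.

```matlab
% Sketch (assumption): forward Euler with a hypothetical step size dt
% stands in for the ode45 solution; no network is involved here.
f  = @(x) [x(2); 10*sin(x(1)) - x(2)];   % same right-hand side as odefcn
x0 = [pi/2; 0];
dt = 0.005;                              % hypothetical step size
N  = 1000;
x  = zeros(2, N);
x(:, 1) = x0;
for k = 1:N-1
    x(:, k+1) = x(:, k) + dt*f(x(:, k)); % forward Euler step
end

X = x(:, 1:end-1);   % inputs:  states 1 ... N-1
Y = x(:, 2:end);     % targets: states 2 ... N

% The first training input is exactly the initial state x0 ...
disp(isequal(X(:, 1), x0))
% ... and each target is the following input, which is what the
% closed-loop prediction relies on when it feeds outputs back in.
disp(isequal(Y(:, 1:end-1), X(:, 2:end)))
```

In the closed-loop prediction, xpred(:, i) takes the place of X(:, i), so a wrong first output becomes the input for all following steps.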
1 Comment
Michael Hesse on 19 Nov 2020
Here is a possible workaround: instead of learning the next state, the network learns the difference to the next state, and the inputs are z-score normalized in the sequenceInputLayer.
%% - cleanup
clear;
close all;
clc;
%% - data
t = linspace(0, 5, 1000);
odefcn = @(t, x) [x(2, :); 10*sin(x(1, :))-x(2, :)];
x0 = [pi/2, 0]';
[~, x] = ode45(odefcn, t, x0);
x = x';
X = x(:, 1:end-1);                  % inputs: states at steps 1 ... N-1
Y = x(:, 2:end) - x(:, 1:end-1);    % targets: difference to the next state
%% - define and train lstm network
numFeatures = 2;
numResponses = 2;
numHiddenUnits = 200;
layers = [sequenceInputLayer(numFeatures, 'Normalization', 'zscore');
    lstmLayer(numHiddenUnits);
    fullyConnectedLayer(numResponses);
    regressionLayer];
opts = trainingOptions('adam', 'MaxEpochs', 100, 'Plots', 'training-progress');
net = trainNetwork(X, Y, layers, opts);
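%% - prediction
net = resetState(net);   % start from the initial network state, as in the first script
xpred = x0;
for i = 1 : length(t)-1
    % predict the increment, then add it to the current state
    [net, dxpred] = predictAndUpdateState(net, xpred(:, i));
    xpred(:, i+1) = xpred(:, i) + dxpred;
end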
%% - plotting
figure(1);
plot(t, x);
hold on;
grid on;
plot(t, xpred, '--');
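The workaround trains on increments dx_k = x_{k+1} - x_k and rebuilds the trajectory by accumulating them from x0. A minimal sketch of that bookkeeping, assuming exact increments (no network is trained) and a hypothetical damped oscillation as placeholder data, shows that the loop from the prediction section and a cumulative sum give the same trajectory. Because every later state is x0 plus the sum of all earlier increments, an error in one increment shifts all states after it.

```matlab
% Sketch (assumption): placeholder trajectory instead of the pendulum
% data; the increments here are exact, not network predictions.
t  = linspace(0, 5, 1000);
x  = [pi/2 + 0.5*exp(-t).*cos(3*t);    % hypothetical state 1
      -0.5*exp(-t).*sin(3*t)];         % hypothetical state 2
x0 = x(:, 1);

% Training targets of the workaround: difference to the next state
dx = x(:, 2:end) - x(:, 1:end-1);      % same as diff(x, 1, 2)

% Reconstruction as in the prediction loop: add each increment in turn
xloop = x0;
for i = 1 : size(dx, 2)
    xloop(:, i+1) = xloop(:, i) + dx(:, i);
end

% Equivalent vectorized form: x0 plus the running sum of increments
xsum = [x0, x0 + cumsum(dx, 2)];

fprintf('max loop error:   %g\n', max(abs(xloop(:) - x(:))));
fprintf('max cumsum error: %g\n', max(abs(xsum(:) - x(:))));
```

In the actual workaround, dx(:, i) is replaced by the network output dxpred, so the reconstructed trajectory is only as accurate as the accumulated predicted increments.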


Answers (0)

Categories

Sequence and Numeric Feature Data Workflows

Version

R2020b
