MATLAB Answers

Problems in defining a custom Neural network layer

Ahmad Momani on 23 Apr 2021
Answered: Joss Knight on 29 Apr 2021
Hi,
I am trying to define a layer that evaluates the state-space equation Y=AX+BU inside a neural network. The aim of the network is to learn the entries of matrices A and B and then evaluate the state-space equation, as shown below. I used small numbers of states and inputs for simplicity, but these numbers can change.
The following code defines the custom layer. Its input is a single column vector holding X, U and the entries of A and B, in that order; the layer reshapes this vector into the required vectors and matrices and then evaluates the state-space equation.
classdef statespace < nnet.layer.Layer % & nnet.layer.Formattable (Optional)
    properties
        % (Optional) Layer properties.
        % Layer properties go here.
        statesNo=1;% default number of states, change if necessary
        inputsNo=1;% default number of inputs, change if necessary
    end
    properties (Learnable)
        % (Optional) Layer learnable parameters.
        % Layer learnable parameters go here.
        % No learnable parameters
    end
    methods
        function layer = statespace(numstates,numinputs)
            % (Optional) Create a myLayer.
            % This function must have the same name as the class.
            % Layer constructor function goes here.
            layer.statesNo=numstates;
            layer.inputsNo=numinputs;
            layer.Name='state space layer';
        end
        function [Z1] = predict(layer, X1)
            % Forward input data through the layer at prediction time and
            % output the result.
            %
            % Inputs:
            %   layer - Layer to forward propagate through
            %   X1, ..., Xn - Input data
            % Outputs:
            %   Z1, ..., Zm - Outputs of layer forward function
            % Layer forward function for prediction goes here.
            n=layer.statesNo;
            u=layer.inputsNo;
            Y=X1';
            X=reshape(Y(1:n),[n 1]);
            U=reshape(Y(n+1:n+u),[u 1]);
            A=reshape(Y(n+u+1:n+u+n*n),[n n]);
            B=reshape(Y(n+u+n*n+1:n+u+n*n+n*u),[n u]);
            % find the output of the state space equation
            Z1=A*X+B*U;
        end
        function [Z1] = forward(layer, X1)
            % (Optional) Forward input data through the layer at training
            % time and output the result and a memory value.
            %
            % Inputs:
            %   layer - Layer to forward propagate through
            %   X1, ..., Xn - Input data
            % Outputs:
            %   Z1, ..., Zm - Outputs of layer forward function
            %   memory - Memory value for custom backward propagation
            % Layer forward function for training goes here.
            n=layer.statesNo;
            u=layer.inputsNo;
            Y=X1;
            X=reshape(Y(1:n),[n 1]);
            U=reshape(Y(n+1:n+u),[u 1]);
            A=reshape(Y(n+u+1:n+u+n*n),[n n]);
            B=reshape(Y(n+u+n*n+1:n+u+n*n+n*u),[n u]);
            % find the output of the state space equation
            Z1=A*X+B*U;
        end
        %{
        % No need for backprop definition
        function [dLdX1, dLdXn, dLdW1, dLdWk] = ...
                backward(layer, X1, Z1, dLdZ1)
            % (Optional) Backward propagate the derivative of the loss
            % function through the layer.
            %
            % Inputs:
            %   layer - Layer to backward propagate through
            %   X1, ..., Xn - Input data
            %   Z1, ..., Zm - Outputs of layer forward function
            %   dLdZ1, ..., dLdZm - Gradients propagated from the next layers
            %   memory - Memory value from forward function
            % Outputs:
            %   dLdX1, ..., dLdXn - Derivatives of the loss with respect to the
            %                       inputs
            %   dLdW1, ..., dLdWk - Derivatives of the loss with respect to each
            %                       learnable parameter
            % Layer backward function goes here.
        end
        %}
    end
end
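To make the input layout concrete, here is a small worked example (not from the thread) for statespace(2,1), the configuration passed to checkLayer below. It packs arbitrary illustration values into the 9-element column the layer expects (n + u + n*n + n*u = 2 + 1 + 4 + 2) and unpacks it with the same reshape calls as predict and forward, so A and B come back unchanged and the result is A*X + B*U. The variable names are hypothetical; nu stands for the layer's u.

```matlab
% Worked example of the packed input vector for statespace(2,1).
% All numbers are arbitrary illustration values.
n = 2;                 % number of states
nu = 1;                % number of inputs (u in the layer code)
x = [1; 2];            % state vector X
uIn = 3;               % input vector U
A = [1 2; 3 4];        % n-by-n state matrix
B = [5; 6];            % n-by-nu input matrix

% Pack in the order the layer expects: [X; U; A(:); B(:)]
packed = [x; uIn; A(:); B(:)];   % length n + nu + n*n + n*nu = 9

% Unpack exactly as predict/forward do (column-major reshape)
Xr = reshape(packed(1:n), [n 1]);
Ur = reshape(packed(n+1:n+nu), [nu 1]);
Ar = reshape(packed(n+nu+1:n+nu+n*n), [n n]);
Br = reshape(packed(n+nu+n*n+1:n+nu+n*n+n*nu), [n nu]);

% A*X = [5; 11] and B*U = [15; 18], so Z1 = [20; 29]
Z1 = Ar*Xr + Br*Ur;
```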
Verifying the layer with checkLayer gives the following:
>> layer = statespace(2,1);
>> checkLayer(layer,[9,1],'ObservationDimension',2)
Skipping GPU tests. No compatible GPU device found.
Running nnet.checklayer.TestLayerWithoutBackward
.......... ...
Done nnet.checklayer.TestLayerWithoutBackward
__________
Test Summary:
13 Passed, 0 Failed, 0 Incomplete, 4 Skipped.
Time elapsed: 0.058491 seconds.
>>
Adding the state-space layer to the network below gives an error about the layer's input:
>> Untitled4
Error using trainNetwork (line 170)
Invalid network.
Error in Untitled4 (line 70)
netold = trainNetwork(XTrain1,YTrain1,lgraph,options);
Caused by:
Layer 'state space layer': Invalid network. Layer statespace does not support sequence input. Try using a sequence folding layer before the layer.
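The message mentions sequence input but not the operation that actually failed. Below is a minimal reproduction (not from the thread), assuming the 20-by-1-by-3 channels-by-batch-by-time sample array that the answer below describes: the line Y=X1' in predict cannot run on such an array, because MATLAB's transpose operator is defined only for 2-D arrays.

```matlab
% Stand-in for the 20-by-1-by-3 (channels-by-batch-by-time) sample array
% that the layer receives while the network is being configured
X1 = ones(20, 1, 3);

transposeFailed = false;
try
    Y = X1';               % same operation as Y=X1' in the layer's predict
catch err
    transposeFailed = true;
    disp(err.message)      % print the transpose error
end
```

checkLayer did not catch this because it was called with a 2-D [9,1] input and no time dimension.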
The code for the network ("Untitled4") is shown below:
%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Test Code %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%
n=3;% number of states
u=2;% number of inputs
XTrain1=[1;1;1;1;1];% states and inputs stacked into one column vector [states,inputs]'
YTrain1=ones(3,1); % the expected next state
%%
numFeatures = size(XTrain1,1);
numResponses = size(YTrain1,1);
%%
lgraph = layerGraph();% empty layer graph
tempLayers = sequenceInputLayer(length(XTrain1),"Name","sequence");
lgraph = addLayers(lgraph,tempLayers);
tempLayers = [
    lstmLayer(128,"Name","lstm_1")
    lstmLayer(128,"Name","lstm_2")
    lstmLayer(128,"Name","lstm_3")
    fullyConnectedLayer(n*n+n*u,"Name","fc_1")
    ];
lgraph = addLayers(lgraph,tempLayers);
%
tempLayers = [
    concatenationLayer(1,2,"Name","concat")
    statespace(3,2)
    regressionLayer("Name","MSE")
    ];
lgraph = addLayers(lgraph,tempLayers);
% clean up helper variable
clear tempLayers;
lgraph = connectLayers(lgraph,"sequence","lstm_1");
lgraph = connectLayers(lgraph,"sequence","concat/in1");
lgraph = connectLayers(lgraph,"fc_1","concat/in2");
%%
plot(lgraph);
%%
options = trainingOptions('adam', ...
    'MaxEpochs',10, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.005, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',125, ...
    'LearnRateDropFactor',0.2, ...
    'Verbose',1);
netold = trainNetwork(XTrain1,YTrain1,lgraph,options);
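A quick size check (not from the thread) shows that the channel dimension of this network is consistent: concat stacks the 5 sequence-input features on top of the n*n + n*u = 15 outputs of fc_1, giving exactly the 20 values that statespace(3,2) unpacks. So the failure is not a channel-count mismatch; it comes from the extra batch and time dimensions. The variable names below are hypothetical; nu stands for the script's u.

```matlab
% Channel bookkeeping for the Untitled4 network (values from the script)
n = 3;                                   % number of states
nu = 2;                                  % number of inputs (u in the script)
numFeatures = 5;                         % size(XTrain1,1): states and inputs stacked
fcOutputs = n*n + n*nu;                  % fc_1 outputs the entries of A and B
concatChannels = numFeatures + fcOutputs;    % concat joins along dimension 1
layerExpects = n + nu + n*n + n*nu;          % values statespace(3,2) unpacks
fprintf('concat: %d channels, layer expects: %d\n', concatChannels, layerExpects);
```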
I am not sure what I am doing wrong; any help is appreciated.

Answers (1)

Joss Knight on 29 Apr 2021
Your custom layer doesn't work for input sequences, and it needs to. The particular error that is happening (which you can see if you set a breakpoint in your custom layer) is that you're trying to transpose a 20x1x3 array: during configuration your custom layer is tested using a sample sequence length of 3, and the transpose operator (') is not defined for N-D arrays. The output of a fully connected layer is channels-by-batch-by-time, and your layer receives that array (here 20 channels: the 5 sequence-input features concatenated with the 15 outputs of fc_1). Fix that and you'll get other errors, for instance in the concatenation layer when the batch size is not 1. Your network needs to handle variable batch size and variable sequence length.
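Following the answer, here is a minimal sketch (not from the thread, and not tested inside a network) of how the body of predict could handle a C-by-N-by-T input instead of a single column. It treats every (observation, time step) column independently, reshapes all A and B entries into pages, and computes the page-wise products with implicit expansion, because pagemtimes is not available in R2020a. The sizes and the random input are illustration values; inside the layer, n and nu would come from layer.statesNo and layer.inputsNo, and X1 would be the predict argument. The operations used (reshape, indexing, .*, sum) are ones dlarray supports, but whether this change alone lets trainNetwork accept the network is an assumption; the answer notes that other layers may still need attention.

```matlab
% Sketch of a batch- and sequence-aware version of the state-space computation.
% Assumes the input is C-by-N-by-T (channels-by-batch-by-time) and each
% channel vector is packed as [X; U; A(:); B(:)].
n = 3;  nu = 2;                       % statespace(3,2) from the question
C = n + nu + n*n + n*nu;              % 20 channels
X1 = rand(C, 4, 3);                   % example input: batch of 4, sequence length 3

[~, N, T] = size(X1);                 % T is 1 for a 2-D (C-by-N) input
K = N*T;                              % number of independent A*X + B*U problems
Y = reshape(X1, C, K);                % one column per (observation, time step)

X = reshape(Y(1:n, :), 1, n, K);              % states, laid out along dimension 2
U = reshape(Y(n+1:n+nu, :), 1, nu, K);        % inputs, laid out along dimension 2
A = reshape(Y(n+nu+1:n+nu+n*n, :), n, n, K);  % one n-by-n matrix per page
B = reshape(Y(n+nu+n*n+1:end, :), n, nu, K);  % one n-by-nu matrix per page

% Page-wise matrix-vector products via implicit expansion:
% sum(A .* X, 2) equals A(:,:,k)*x_k for every page k
Z = sum(A .* X, 2) + sum(B .* U, 2);          % n-by-1-by-K
Z1 = reshape(Z, n, N, T);                     % back to channels-by-batch-by-time
```

Because the output keeps the batch and time dimensions, the layer no longer collapses a whole batch into one n-by-1 result, which the original single-vector reshape would do for batch sizes above 1.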


Release

R2020a
