Using a transformer neural network for a classification task

haohaoxuexi1 on 28 Jul 2024
Commented: Joss Knight on 13 Aug 2024
numChannels = inputSize;
maxPosition = 256;
numHeads = 4;
numKeyChannels = numHeads*32;
layers = [
    sequenceInputLayer(numChannels,Name="input")
    positionEmbeddingLayer(numChannels, maxPosition, Name="pos-emb")
    additionLayer(2, Name="add")
    selfAttentionLayer(numHeads,numKeyChannels,'AttentionMask','causal')
    selfAttentionLayer(numHeads,numKeyChannels)
    indexing1dLayer("last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];
% layerGraph connects the array in order, so "add" receives "pos-emb" on add/in1;
% connect the raw input to add/in2 so the input is summed with its position embedding.
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph, "input", "add/in2");
maxEpochs = 100;
miniBatchSize = 32;
learningRate = 0.001;
solver = 'adam';
shuffle = 'every-epoch';
gradientThreshold = 10;
executionEnvironment = "auto"; % chooses local GPU if available, otherwise CPU
options = trainingOptions(solver, ...
'Plots','training-progress', ...
'MaxEpochs', maxEpochs, ...
'MiniBatchSize', miniBatchSize, ...
'Shuffle', shuffle, ...
'InitialLearnRate', learningRate, ...
'GradientThreshold', gradientThreshold, ...
'ExecutionEnvironment', executionEnvironment);
The input size is 12, so there are 12 features.
numClasses is 4, so I am classifying the data into 4 classes.
However, I get the following error when I run it:
"
Error in test123_20240727 (line 195)
net=trainNetwork(XTrain, YTrain, layers, options);
Caused by:
Layer 'add': Unconnected input. Each layer input must be connected to the output of another layer.
"
Line 195 is "net=trainNetwork(XTrain, YTrain, layers, options);".
Can anyone help me with this?
  7 Comments
Umar on 29 Jul 2024
Hi @haohaoxuexi1,
If you are still having issues with modifying your code, please let us know. We will be happy to help you out.
haohaoxuexi1 on 29 Jul 2024
@Umar Hi Umar, I'm fine for now. I'll let you know if I have further questions.


Accepted Answer

Joss Knight on 29 Jul 2024

You've passed layers instead of lgraph to trainNetwork.

  2 Comments
Umar on 29 Jul 2024
@Joss Knight, thanks for jumping in. Could you please advise how to pass lgraph to trainNetwork, with a code snippet? Thanks again for your help.
Joss Knight on 13 Aug 2024
net=trainNetwork(XTrain, YTrain, lgraph, options);
instead of
net=trainNetwork(XTrain, YTrain, layers, options);
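The sketch below (not from the thread) reproduces the fix end to end. The reason it works: connectLayers returns a new layer graph and leaves the original layers array untouched, so only lgraph contains the input -> add/in2 connection. When trainNetwork receives the plain array, it connects the layers sequentially and add/in2 stays unconnected. The synthetic data (64 random fixed-length sequences with 12 features and 4 random class labels), the variables numObs and seqLength, and the reduced MaxEpochs/MiniBatchSize are hypothetical stand-ins for the poster's XTrain, YTrain and options, chosen only so the example runs on its own.

```matlab
% Hypothetical synthetic data standing in for the poster's XTrain/YTrain:
% numObs sequences, each inputSize features x seqLength time steps.
rng(0)
numObs = 64;
inputSize = 12;
numClasses = 4;
seqLength = 40;                      % fixed length, so no padding is needed
XTrain = cell(numObs, 1);
for i = 1:numObs
    XTrain{i} = randn(inputSize, seqLength);
end
YTrain = categorical(randi(numClasses, numObs, 1));

% Architecture from the question (unchanged apart from formatting).
numChannels = inputSize;
maxPosition = 256;
numHeads = 4;
numKeyChannels = numHeads*32;
layers = [
    sequenceInputLayer(numChannels, Name="input")
    positionEmbeddingLayer(numChannels, maxPosition, Name="pos-emb")
    additionLayer(2, Name="add")
    selfAttentionLayer(numHeads, numKeyChannels, AttentionMask="causal")
    selfAttentionLayer(numHeads, numKeyChannels)
    indexing1dLayer("last")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

% layerGraph wires the array in order ("pos-emb" -> "add/in1");
% connectLayers returns a new graph with the extra input -> add/in2 edge.
% The original `layers` array is not modified.
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph, "input", "add/in2");

% Short, quiet run for illustration only.
options = trainingOptions("adam", ...
    MaxEpochs=2, ...
    MiniBatchSize=16, ...
    Verbose=false, ...
    Plots="none");

% Train the connected graph. Passing `layers` here instead reproduces
% "Layer 'add': Unconnected input."
net = trainNetwork(XTrain, YTrain, lgraph, options);
YPred = classify(net, XTrain);
```

Replacing lgraph with layers in the trainNetwork call brings back the "Unconnected input" error from the question.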



Version: R2024a
