How to create a transformer network for sequence to sequence classification task?
I am trying to use MATLAB to classify time series with a transformer network. My code is below, but I cannot resolve the error I get when I run it.
lgraph = [ ...
sequenceInputLayer(InputSize,Name="input")
positionEmbeddingLayer(InputSize,maxPosition,Name="pos-emb");
additionLayer(2, Name="embed_add");
selfAttentionLayer(numHeads,numKeyChannels) % self attention
additionLayer(2,Name="attention_add") % residual connection around attention
layerNormalizationLayer(Name="attention_norm") % layer norm
fullyConnectedLayer(feedforwardHiddenSize) % feedforward part 1
reluLayer % nonlinear activation
fullyConnectedLayer(attentionHiddenSize) % feedforward part 2
additionLayer(2,Name="feedforward_add") % residual connection around feedforward
layerNormalizationLayer() % layer norm
% selfAttentionLayer(numHeads,numKeyChannels,'AttentionMask','causal');
% selfAttentionLayer(numHeads,numKeyChannels);
indexing1dLayer("last")
fullyConnectedLayer(NumClass)
softmaxLayer
classificationLayer];
% Layers = layerGraph(lgraph);
% Layers = connectLayers(Layers,"input","add/in2");
net = dlnetwork(lgraph,Initialize=false);
net = connectLayers(net,"embed_add","attention_add/in2");
net = connectLayers(net,"pos-emb","embed_add/in2");
net = connectLayers(net,"attention_norm","feedforward_add/in2");
% net = connectLayers(net,"encoder1_out","attention2_add/in2");
% net = connectLayers(net,"attention2_norm","feedforward2_add/in2");
net = initialize(net);
Answers (1)
Prasanna
on 9 Sep 2024
Hi veritas,
The error you're encountering is caused by the ‘classificationLayer’, which is not supported in a ‘dlnetwork’ object. A ‘dlnetwork’ does not take an explicit output layer such as ‘classificationLayer’; the network ends at the last learnable or activation layer (here, ‘softmaxLayer’), and the loss is computed separately during training.
Here's how you can modify your setup to avoid using classificationLayer:
- Remove the ‘classificationLayer’ from your layer graph definition.
- With ‘dlnetwork’, you typically use a custom training loop where you manually compute the loss and update the model parameters.
- Use a loss function such as cross-entropy directly in your training loop.
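As a rough sketch of the steps above, the network could be built without an output layer and the loss computed by hand on one mini-batch. This assumes R2024a or later (for the empty `dlnetwork` constructor and `addLayers` on a `dlnetwork`). All sizes and the random data are hypothetical placeholders, the layer names other than those in the question are made up for illustration, and the second feedforward layer is set to `InputSize` here so that the residual additions have matching channel counts.

```matlab
% Hypothetical sizes, chosen only for illustration
InputSize = 12; maxPosition = 256;
numHeads = 4; numKeyChannels = 64;      % numKeyChannels must be divisible by numHeads
feedforwardHiddenSize = 128; NumClass = 5;

% Main branch; note there is no classificationLayer at the end
layers = [
    sequenceInputLayer(InputSize,Name="input")
    additionLayer(2,Name="embed_add")
    selfAttentionLayer(numHeads,numKeyChannels,Name="attention")
    additionLayer(2,Name="attention_add")        % residual around attention
    layerNormalizationLayer(Name="attention_norm")
    fullyConnectedLayer(feedforwardHiddenSize,Name="ff1")
    reluLayer(Name="relu")
    fullyConnectedLayer(InputSize,Name="ff2")    % back to InputSize for the residual
    additionLayer(2,Name="feedforward_add")      % residual around feedforward
    layerNormalizationLayer(Name="feedforward_norm")
    indexing1dLayer("last",Name="last")
    fullyConnectedLayer(NumClass,Name="fc_out")
    softmaxLayer(Name="softmax")];

net = dlnetwork;
net = addLayers(net,layers);
net = addLayers(net,positionEmbeddingLayer(InputSize,maxPosition,Name="pos-emb"));

% Position embedding branch and the two skip connections
net = connectLayers(net,"input","pos-emb");
net = connectLayers(net,"pos-emb","embed_add/in2");
net = connectLayers(net,"embed_add","attention_add/in2");
net = connectLayers(net,"attention_norm","feedforward_add/in2");
net = initialize(net);

% Manual loss on one random mini-batch of 8 sequences with 20 time steps
X = dlarray(rand(InputSize,8,20,"single"),"CBT");
T = onehotencode(categorical(randi(NumClass,1,8),1:NumClass),1);  % NumClass-by-8
Y = forward(net,X);          % class probabilities, one column per observation
loss = crossentropy(Y,T);
```

In a full custom training loop, the forward pass and loss would sit inside a model loss function that is evaluated with `dlfeval`, so that `dlgradient` can compute the gradients for the parameter update.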
Alternatively, if you do not need a custom training loop, you can use the ‘trainnet’ function, which trains ‘dlnetwork’ objects directly, and set its loss function argument to "crossentropy". For more information on these functions, refer to the following documentation:
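A minimal sketch of the `trainnet` route, assuming R2024a or later. The data is synthetic and the network is deliberately reduced to a single attention layer so the example stays short; the sizes, variable names, and training options are hypothetical and would need to be replaced with your own.

```matlab
% Hypothetical sizes and synthetic data, for illustration only
InputSize = 12; NumClass = 5; numObs = 64; numTimeSteps = 20;
XTrain = cell(numObs,1);
for i = 1:numObs
    XTrain{i} = rand(InputSize,numTimeSteps);    % channels-by-time
end
TTrain = categorical(mod((0:numObs-1)',NumClass) + 1);   % labels 1..NumClass

% Reduced network: no classificationLayer, softmax is the last layer
layers = [
    sequenceInputLayer(InputSize)
    selfAttentionLayer(4,64)
    indexing1dLayer("last")
    fullyConnectedLayer(NumClass)
    softmaxLayer];

options = trainingOptions("adam", ...
    MaxEpochs=2, ...
    MiniBatchSize=16, ...
    InputDataFormats="CTB", ...    % sequences above are channels-by-time
    Verbose=false);

% The loss is passed to trainnet instead of being a layer in the network
netTrained = trainnet(XTrain,TTrain,layers,"crossentropy",options);

% Predict scores and convert them to class labels
scores = minibatchpredict(netTrained,XTrain,InputDataFormats="CTB");
YPred = scores2label(scores,categories(TTrain));
```

`trainnet` accepts either a layer array, as here, or a `dlnetwork` object, so a network with skip connections built with `connectLayers` can be passed in the same position.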
- ‘trainnet’: https://www.mathworks.com/help/deeplearning/ref/trainnet.html
- Detected output layer in dlnetwork: https://www.mathworks.com/matlabcentral/answers/1920960
- dlnetwork training problem with outer layer: https://www.mathworks.com/matlabcentral/answers/2112261
- Classify sequence data using Deep learning: https://www.mathworks.com/help/deeplearning/ug/classify-sequence-data-using-lstm-networks.html
Hope this helps!