When I use attentionLayer with Q, K, and V inputs, MATLAB reports an error
I am using the MATLAB R2024b release to construct an LSTM-attention-LSTM network. The key and value come from the output of lstm1, and the query comes from lstm2's output. However, when I build this, MATLAB reports a network cycle error. How can I achieve my goal without this error? Thank you very much.
Here is my code and layers:
net = dlnetwork;
layers = [
    sequenceInputLayer(InputChannels, Name="input")
    lstmLayer(hidden_num_lstm1, 'OutputMode', 'sequence', Name='lstm_encoder')
    dropoutLayer(0.2, Name='dropout_1')
    layerNormalizationLayer(Name='batchnormal_1')
    fullyConnectedLayer(AttentionChannels, Name="key") % Key
    attentionLayer(NumHeads, Name="cross-attention")
    layerNormalizationLayer
    lstmLayer(hidden_num_lstm2, 'OutputMode', 'sequence', Name='lstm_decoder')
    dropoutLayer(0.2, Name='dropout_2')
    layerNormalizationLayer(Name='batchnormal_2')
    fullyConnectedLayer(64, Name='fc')
    fullyConnectedLayer(OutputChannels, Name='fc_out')
    ];
net = addLayers(net,layers);
net = connectLayers(net,'key','cross-attention/key');
net = disconnectLayers(net,'key','cross-attention/query');
plot(net);
layers = [
    fullyConnectedLayer(AttentionChannels, Name="value") % Value
    ];
net = addLayers(net,layers);
net = connectLayers(net,'batchnormal_1','value');
net = connectLayers(net,'value','cross-attention/value');
plot(net);
layers = [
    fullyConnectedLayer(AttentionChannels, Name="query") % Query
    ];
net = addLayers(net,layers);
net = connectLayers(net,'lstm_decoder','query');
net = connectLayers(net,'query','cross-attention/query');
plot(net);

Answers (1)
Malay Agarwal on 22 Jan 2025 (edited 22 Jan 2025)
You get this error because there is a cycle in your neural network. To train the network, MATLAB must be able to perform a topological sort on the model graph (the plot that you are seeing). When the model graph contains a cycle, no topological sort exists, which is why MATLAB raises the error.
The cycle exists because the output of the fully connected layer named "query" is connected to an input of the attention layer named "cross-attention", while all the layers are also connected one after another, as in a traditional sequential network.
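As a minimal illustration (a hypothetical 3-node toy graph, not your network), MATLAB's toposort function errors as soon as a directed graph contains a cycle, while removing any one edge of the cycle makes the sort succeed:

```matlab
% Hypothetical toy graph: edges 1 -> 2 -> 3 -> 1 form a cycle.
G = digraph([1 2 3], [2 3 1]);
try
    order = toposort(G);        % toposort requires an acyclic graph
catch err
    disp(err.message)           % fails here because the graph has a cycle
end
% Breaking any one edge of the cycle (e.g. 3 -> 1) makes the sort possible:
order = toposort(rmedge(G, 3, 1));
```

This is the same situation MATLAB faces with your layer graph: the edge from "query" back into "cross-attention" closes a loop, so no valid layer ordering exists.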
Looking at the code, it seems you are trying to build an encoder-decoder LSTM network with a cross-attention mechanism. In such a setup, the encoder and the decoder are typically disconnected from each other and operate on different inputs. For example, the input to the encoder could be English sentences, while the input to the decoder could be the corresponding French translations of those sentences. The key and value for the attention layer then come from the output of the encoder, and the query comes from an intermediate step in the decoder.
You can set up such a network in MATLAB as follows:
numChannels = 256;
numHeads = 8;
net = dlnetwork;
% Set up the decoder
% It has its own sequenceInputLayer
% Also the query for attentionLayer comes from the output of the decoder
layers = [
    sequenceInputLayer(1, Name="decoder-input")
    lstmLayer(64, 'OutputMode', 'sequence', Name='lstm-decoder')
    layerNormalizationLayer(Name='layernorm_decoder')
    fullyConnectedLayer(numChannels, Name="query")
    attentionLayer(numHeads, Name="cross-attention")];
net = addLayers(net,layers);
% Set up the encoder
% It has its own sequenceInputLayer
layers = [
    sequenceInputLayer(1, Name="encoder-input")
    lstmLayer(64, 'OutputMode', 'sequence', Name='lstm-encoder')
    layerNormalizationLayer(Name='key-value')
    % Add a fully connected layer to extract the key
    fullyConnectedLayer(numChannels, Name="fc-key")];
net = addLayers(net,layers);
% Connect the key layer to the attention layer's key input
net = connectLayers(net,"fc-key","cross-attention/key");
% Add a fully-connected layer to extract the value
net = addLayers(net, fullyConnectedLayer(numChannels,Name="fc-value"));
% Connect the added layer with the key-value layer to extract the value
net = connectLayers(net,"key-value","fc-value");
% Connect the value layer to the attention layer's value input
net = connectLayers(net,"fc-value","cross-attention/value");
plot(net)
Notice how there are no cycles in the network. You will also notice that there are no errors reported when you call analyzeNetwork on this network.
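As a quick sanity check (a sketch with hypothetical sequence lengths), you can initialize the network and run a forward pass. Note that the order of the inputs passed to predict must match net.InputNames:

```matlab
% Assumes "net" is the encoder-decoder network built above.
net = initialize(net);
Xdec = dlarray(rand(1, 50), "CT");   % decoder input: 1 channel x 50 time steps
Xenc = dlarray(rand(1, 80), "CT");   % encoder input: 1 channel x 80 time steps
% Pass the inputs in the order given by net.InputNames
% (here "decoder-input" first, then "encoder-input").
Y = predict(net, Xdec, Xenc);
size(Y)   % channels x decoder time steps: one attention output per query step
```

If the graph still contained a cycle, initialize would raise the same error you saw, so a clean forward pass confirms the wiring is valid.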
Please refer to the following resources for more information:
- attentionLayer documentation - https://www.mathworks.com/help/deeplearning/ref/nnet.cnn.layer.attentionlayer.html
- Details on cross attention (and other types of attention mechanisms) - https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention
- Details on how training of neural networks works - https://minitorch.github.io/module1/backpropagate/
Hope this helps!
