how to make cross attention use attentionlayer?

15 views (last 30 days)
dan on 18 Dec 2024
Commented: dan on 19 Dec 2024
I want to replace the dual-branch merge section of the model in the following link with cross-attention for fusion, but it has not been successful. Am I doing something wrong? I have also written a standalone example, but I still do not understand how to embed it into the model from the link.
Network one (fails: the loss does not decrease):
initialLayers = [
    sequenceInputLayer(1, "MinLength", numSamples, "Name", "input", "Normalization", "zscore", "SplitComplexInputs", true)
    convolution1dLayer(7, 2, "Stride", 1)
];
stftBranchLayers = [
    stftLayer("TransformMode", "squaremag", "Window", hann(64), "OverlapLength", 52, "Name", "stft", "FFTLength", 256, "WeightLearnRateFactor", 0)
    functionLayer(@(x)dlarray(x, 'SCBS'), Formattable=true, Acceleratable=true, Name="stft_reformat")
    convolution2dLayer([4, 8], 16, "Padding", "same", "Name", "stft_conv_1")
    layerNormalizationLayer("Name", "stft_layernorm_1")
    reluLayer("Name", "stft_relu_1")
    maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_1")
    convolution2dLayer([4, 8], 24, "Padding", "same", "Name", "stft_conv_2")
    layerNormalizationLayer("Name", "stft_layernorm_2")
    reluLayer("Name", "stft_relu_2")
    maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_2")
    convolution2dLayer([4, 8], 32, "Padding", "same", "Name", "stft_conv_3")
    layerNormalizationLayer("Name", "stft_layernorm_3")
    reluLayer("Name", "stft_relu_3")
    maxPooling2dLayer([4, 8], "Stride", [1 2], "Name", "stft_maxpool_3")
    flattenLayer("Name", "stft_flatten")
    dropoutLayer(0.5)
    fullyConnectedLayer(100, "Name", "fc_stft")
];
cwtBranchLayers = [
    cwtLayer("SignalLength", numSamples, "TransformMode", "squaremag", "Name", "cwt", "WeightLearnRateFactor", 0)
    functionLayer(@(x)dlarray(x, 'SCBS'), Formattable=true, Acceleratable=true, Name="cwt_reformat")
    convolution2dLayer([4, 8], 16, "Padding", "same", "Name", "cwt_conv_1")
    layerNormalizationLayer("Name", "cwt_layernorm_1")
    reluLayer("Name", "cwt_relu_1")
    maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_1")
    convolution2dLayer([4, 8], 24, "Padding", "same", "Name", "cwt_conv_2")
    layerNormalizationLayer("Name", "cwt_layernorm_2")
    reluLayer("Name", "cwt_relu_2")
    maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_2")
    convolution2dLayer([4, 8], 32, "Padding", "same", "Name", "cwt_conv_3")
    layerNormalizationLayer("Name", "cwt_layernorm_3")
    reluLayer("Name", "cwt_relu_3")
    maxPooling2dLayer([4, 8], "Stride", [1 4], "Name", "cwt_maxpool_3")
    flattenLayer("Name", "cwt_flatten")
    dropoutLayer(0.5)
    fullyConnectedLayer(100, "Name", "fc_cwt")
];
finalLayers = [
    attentionLayer(4, "Name", "attention")
    layerNormalizationLayer("Name", "layernorm")
    fullyConnectedLayer(48, "Name", "fc_1")
    fullyConnectedLayer(numel(waveformClasses), "Name", "fc_2")
    softmaxLayer("Name", "softmax")
];
dlLayers2 = dlnetwork(initialLayers);
dlLayers2 = addLayers(dlLayers2, stftBranchLayers);
dlLayers2 = addLayers(dlLayers2, cwtBranchLayers);
dlLayers2 = addLayers(dlLayers2, finalLayers);
dlLayers2 = connectLayers(dlLayers2, "conv1d", "stft");
dlLayers2 = connectLayers(dlLayers2, "conv1d", "cwt");
dlLayers2 = connectLayers(dlLayers2,"fc_stft","attention/key");
dlLayers2 = connectLayers(dlLayers2,"fc_stft","attention/value");
dlLayers2 = connectLayers(dlLayers2,"fc_cwt","attention/query");
My example (is it correct?):
numChannels = 10;
numObservations = 128;
numTimeSteps = 100;
X = rand(numChannels,numObservations,numTimeSteps);
X = dlarray(X);
Y = rand(numChannels,numObservations,numTimeSteps);
Y = dlarray(Y);
numHeads = 8;
outputSize = numChannels*numHeads;
WQ = rand(outputSize, numChannels, 1, 1);
WK = rand(outputSize, numChannels, 1, 1);
WV = rand(outputSize, numChannels, 1, 1);
WO = rand(outputSize, outputSize, 1, 1);
Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO);
function Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO)
    queries = WQ * X;
    keys = WK * Y;
    values = WV * Y;
    A = attention(queries, keys, values, numHeads, 'DataFormat', 'CBT');
    Z = WO * A;
end

Accepted Answer

Sahas on 18 Dec 2024
Edited: Sahas on 18 Dec 2024
As per my understanding, you would like to replace the dual-branch merge section of the model with cross-attention. I went through your implementation and observed a few things. It looks structurally correct, but make sure of the following points when using cross-attention with the classification workflow shown in the documentation example:
  • Ensure that the dimensions of WQ, WK, WV, and WO align with the channel dimensions of the inputs they multiply, so that each projection produces an output of the expected size.
  • Ensure that the outputs of the fc_stft and fc_cwt layers are compatible with the input dimensions expected by your crossAttention function. Since both branches end with fully connected layers of output size 100, double-check the outputSize variable and confirm that it matches that size.
  • The outputs of the fc_stft and fc_cwt layers should be connected to the inputs of crossAttention rather than directly to the attention layer.
  • Use MATLAB's pagemtimes function for the page-wise matrix multiplication of multi-dimensional arrays such as queries, keys, and values (see the sketch after this list). Here is the MathWorks documentation link for the same: https://www.mathworks.com/help/matlab/ref/pagemtimes.html
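For illustration, here is a minimal sketch of how the crossAttention helper from the question could be written with pagemtimes. It assumes X and Y are unformatted C-by-B-by-T dlarrays and that WQ, WK, WV, and WO are plain 2-D projection matrices, matching the sizes in the question's example:
% Sketch only: names and sizes follow the question's example; the 2-D
% weight matrices are broadcast over the batch and time pages by pagemtimes.
function Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO)
    % pagemtimes applies the 2-D weights to every B-by-T page of the
    % 3-D arrays, which is where plain mtimes (*) errors for 3-D inputs.
    queries = pagemtimes(WQ, X);
    keys    = pagemtimes(WK, Y);
    values  = pagemtimes(WV, Y);
    % Multi-head scaled dot-product attention over the channel dimension.
    A = attention(queries, keys, values, numHeads, 'DataFormat', 'CBT');
    % Output projection back to outputSize channels.
    Z = pagemtimes(WO, A);
end
Called as in the question, Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO) should then return an outputSize-by-numObservations-by-numTimeSteps array.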
Hope this is beneficial!
  1 Comment
dan on 19 Dec 2024
Thank you, it was indeed a dimensionality issue. After incorporating the fully connected (fc) layer and making adjustments, I have resolved the problem.


More Answers (0)

Version

R2024a
