Weight Tying for Layers in a CNN model

Junaid Farooq on 16 Jun 2024
Commented: Samuel Somuyiwa on 18 Jun 2024
I am building a deep learning network that looks like the figure below. I want to tie the weights of Conv1 to Conv2, Conv3 to Conv6, and Conv4 to Conv5. Is it possible to do this in MATLAB? If so, please help me with the code.
  2 Comments
Junaid Farooq on 16 Jun 2024
I tried this:
inputSize = [31 31 1];   % height, width, channels
% First encoder branch: slice, two convolutions, then an addition
encoderLayer1 = [
    dlhdl.layer.sliceLayer(Name="slice",Groups=2,GroupID=2)
    convolution2dLayer(3,32,"Padding",1,"WeightsInitializer","he","Name","conv1")
    convolution2dLayer(3,32,"Padding",1,"WeightsInitializer","he","Name","conv3")
    additionLayer(2,"Name","add")
    ];
% Second encoder branch
encoderLayer2 = convolution2dLayer(3,32,"Padding",1,"WeightsInitializer","he","Name","conv2");
% weightTyingEncoderLayer1 is the custom layer that wraps both branches
layers = [
    imageInputLayer(inputSize,Normalization="none")
    weightTyingEncoderLayer1(encoderLayer1,encoderLayer2)
    ];
net = dlnetwork(layers);
Junaid Farooq on 16 Jun 2024
It gives the following error:
Error using dlnetwork/initialize (line 558)
Invalid network.
Error in dlnetwork (line 167)
net = initialize(net, dlX{:});
Caused by:
Layer 'recursiveNet': Error using the predict function in layer weightTyingEncoderLayer1. The
function threw an error and could not be executed.
Error using convolution2dLayer (line 1)
Invalid argument at position 3. Function requires exactly 2 positional input(s).
Error in weightTyingEncoderLayer1/predict (line 36)
Y1 = convolution2dLayer(Y,this.Encoder.Learnables.Value{1}',this.Conv1);


Answers (2)

Sanjana on 17 Jun 2024
Hi Junaid,
Weight tying is supported in MATLAB through a nested layer. In your code for "weightTyingEncoderLayer1.m", I noticed that you have not declared and defined the "initialize" function. This function is needed to initialize the learnables and parameters of the layers whose weights are to be tied, based on the "numChannels" of those layers.
The error in "convolution2dLayer" occurs because "this.Conv1" is not initialized.
To resolve the errors and implement weight tying using a nested layer, refer to the following example.
Also, refer to the following link for more information on creating a custom nested layer, and in particular the "Custom Layer Template" section, to understand how to create a custom nested layer in MATLAB.
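The last line of the error message also points at a second issue: convolution2dLayer is a layer constructor whose two positional arguments are the filter size and the number of filters, so it cannot be called on data and weights inside predict. The functional counterparts are dlconv and dltranspconv. The following is a minimal sketch, not the code from the linked example. It assumes a single 3-by-3 weight array (named W here purely for illustration) shared between an encoder convolution and a decoder transposed convolution, with zero biases and sizes chosen to match the 31-by-31-by-1 input in the question.

```matlab
% Sketch only: W, bEnc and bDec are illustrative names, not taken from the thread.
X = dlarray(rand(31,31,1,1,'single'),'SSCB');   % one 31x31 single-channel image

W    = dlarray(0.1*randn(3,3,1,32,'single'));   % shared (tied) filter weights
bEnc = zeros(32,1,'single');                    % encoder bias, one value per filter
bDec = zeros(1,1,'single');                     % decoder bias, single output channel

% Encoder: 1 channel -> 32 channels
Z = relu(dlconv(X,W,bEnc,Padding="same"));

% Decoder: 32 channels -> 1 channel, reusing the same W.
% dltranspconv reads unformatted weights as
% filterHeight-by-filterWidth-by-numFilters-by-numChannels,
% so the 3x3x1x32 array maps 32 input channels to 1 output channel.
Y = dltranspconv(Z,W,bDec,Cropping="same");
```

Because both operations read the same W, a gradient computed with dlgradient with respect to W collects contributions from both uses, which is what ties the two layers together.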
  1 Comment
Junaid Farooq on 17 Jun 2024
Thank you, Sanjana. I have already read the articles provided in the links. I would appreciate it if you could help me out by writing the code for the diagram given above.


Samuel Somuyiwa on 17 Jun 2024
See the attached weightTyingAutoEncoder layer example. The layer follows on from the example in the link that Sanjana shared earlier.
See below for an example of how to use the layer. Note that I couldn't understand what dlhdl.layer.sliceLayer in your example was meant to do, or whether it was used correctly, so I have not included it in the example below.
inputSize = [32 32 32];
layers = [
imageInputLayer(inputSize,Normalization="none")
weightTyingAutoEncoderLayer
];
net = dlnetwork(layers);
%%
X = dlarray(rand(inputSize),'SSCB');
Y = predict(net, X);
  2 Comments
Junaid Farooq on 17 Jun 2024
Thanks. With your code, will Conv1 and Conv3 get their input from the input layer?
If we add ReLU layers between the convolution layers, will the values in the code below change?
decoder.Learnables.Value{2} = dlarray(this.DecoderBias1);
decoder.Learnables.Value{4} = dlarray(this.DecoderBias2);
decoder.Learnables.Value{6} = dlarray(this.DecoderBias3);
Samuel Somuyiwa on 18 Jun 2024
Yes, the convolution layers get their input from the input layer.
No, adding a ReLU layer between the convolution layers will not change the values in the code.
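To see why the indices 2, 4 and 6 stay the same, it may help to inspect the Learnables table of a small network. The sketch below uses made-up layer names and is not the decoder from the attached file. It relies only on the fact that ReLU layers have no learnable parameters, so the table still alternates Weights and Bias rows for the three convolutions.

```matlab
% Sketch only: layer names and sizes are illustrative assumptions.
layers = [
    imageInputLayer([32 32 32],Normalization="none")
    convolution2dLayer(3,32,Padding="same",Name="conv1")
    reluLayer(Name="relu1")
    convolution2dLayer(3,32,Padding="same",Name="conv2")
    reluLayer(Name="relu2")
    convolution2dLayer(3,32,Padding="same",Name="conv3")
    ];
net = dlnetwork(layers);

% The ReLU layers contribute no rows, so rows 2, 4 and 6 are the biases.
disp(net.Learnables(:,["Layer","Parameter"]))
```

Adding a layer that does have learnables, such as batch normalization, would insert extra rows and shift these indices.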

Release

R2024a