How can I transfer the model parameters of a well-trained NN to another one?
Wanli Wen
on 24 Nov 2019
Answered: Divya Gaddipati
on 5 Dec 2019
I have two NNs, net_1 and net_2, where net_1 is untrained and net_2 has been well trained. I want to transfer the knowledge of net_2 to net_1 so that net_1 performs as well as net_2, so I wrote the code below. However, after setting the weights and biases of net_1 to those of net_2, net_1 behaves very badly, e.g., net_2(-2) = 3.999 but net_1(-2) = 32.249. I expected net_1 to output a value very close to that of net_2. Could anyone tell me whether there is anything wrong with my code? Thanks.
(Please note that I do not want to use the operation net_1 = net_2 to achieve this purpose.)
clear all
%%
% Task: To fit a non-linear function f(x) = x.^2
%%
D=1e4; % number of training samples for net_2
layers_neurons=[64];
%% Net 1: untrained network
net_1 = feedforwardnet(layers_neurons);
[data1,target1] = gen_data_sample(10);
net_1 = configure(net_1, data1, target1);
%% Net 2: trained network
[data2,target2] = gen_data_sample(D);
net_2 = feedforwardnet(layers_neurons); % doc feedforwardnet for more details
net_2 = configure(net_2, data2, target2);
net_2 = train(net_2,data2, target2); % , 'useGPU', 'yes', 'useparallel', 'yes'
%% Transfer the knowledge of Net 2 to Net 1
net_1.IW = net_2.IW;
net_1.LW = net_2.LW;
net_1.b = net_2.b;
%% Test and Compare Net 1 and Net 2
net_1(-2)
net_2(-2)
%%
function [input,output] = gen_data_sample(D)
%%
input = -20+(20-(-20))*rand(1, D);
output = input.^2;
end
Accepted Answer
Divya Gaddipati
on 5 Dec 2019
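Before you assign the weights of net_2 to net_1, initialize net_1 from net_2 using the init function:
net_1 = init(net_2);
This resolves the issue. IW, LW and b are not the only state that determines the network's output: feedforwardnet also normalizes its inputs and targets (by default with removeconstantrows and mapminmax), and configure derives those normalization ranges from the data it is given. net_1 was configured on only 10 random samples, so its input and target ranges differ from the ones net_2 stored for its 10,000 samples, and the same weights therefore map -2 to a different value. init(net_2) returns a network with net_2's architecture and processing settings, with only the weights and biases reinitialized, so copying IW, LW and b afterwards reproduces net_2's outputs.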
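The cause can be checked directly by comparing the normalization ranges that configure stores in each network. The sketch below is a hypothetical diagnostic (the names x_small, x_large, net_a, net_b and k are illustrative); it assumes the default feedforwardnet processing, where mapminmax is one of the entries in outputs{2}.processFcns, and reads the xmin/xmax fields of its settings.

```matlab
% Hypothetical diagnostic: configure the same architecture on different
% amounts of data and compare the target ranges that mapminmax stores.
x_small = -20 + 40*rand(1, 10);       % 10 samples, like net_1 in the question
x_large = -20 + 40*rand(1, 1e4);      % 10,000 samples, like net_2
net_a = configure(feedforwardnet(64), x_small, x_small.^2);
net_b = configure(feedforwardnet(64), x_large, x_large.^2);

% Locate mapminmax among the output processing functions (assumed present by default)
k = find(strcmp(net_a.outputs{2}.processFcns, 'mapminmax'));
ps_a = net_a.outputs{2}.processSettings{k};
ps_b = net_b.outputs{2}.processSettings{k};

fprintf('net_a target range: [%g, %g]\n', ps_a.xmin, ps_a.xmax);
fprintf('net_b target range: [%g, %g]\n', ps_b.xmin, ps_b.xmax);
```

Because the stored ranges differ, identical IW, LW and b values still produce different outputs after the inverse output mapping.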
Additionally, you can remove the call to configure on net_1, which is not needed when you use init, because init(net_2) replaces net_1 with a copy of net_2's configuration.
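Putting the accepted answer together, a minimal sketch of the corrected workflow might look like the following. It assumes the Deep Learning Toolbox defaults (trainlm as the training function, whose trainParam includes showWindow and epochs); the rng seed and the epoch cap are illustrative additions, not part of the original code.

```matlab
% Sketch: train net_2, then give net_1 net_2's configuration via init
% before copying the trained weights and biases.
rng(0);                                   % illustrative seed for reproducibility
x = -20 + 40*rand(1, 1e4);                % inputs in [-20, 20]
t = x.^2;                                 % targets f(x) = x.^2

net_2 = feedforwardnet(64);
net_2.trainParam.showWindow = false;      % suppress the training window
net_2.trainParam.epochs = 200;            % illustrative cap on training time
net_2 = train(net_2, x, t);

% init(net_2) keeps net_2's architecture and input/output processing
% settings but draws new weights and biases; no configure call on net_1.
net_1 = init(net_2);
net_1.IW = net_2.IW;
net_1.LW = net_2.LW;
net_1.b  = net_2.b;

y1 = net_1(-2);
y2 = net_2(-2);
fprintf('net_1(-2) = %.4f, net_2(-2) = %.4f\n', y1, y2);
```

Since both networks now share weights, biases and processing settings, the two printed values should agree.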
For more information on configure and init, see: https://www.mathworks.com/help/deeplearning/ug/create-configure-and-initialize-multilayer-neural-networks.html#bss330n-3
Hope this helps!