How to add new classes to a neural network?

Niculai Traian on 12 Sep 2018

I built a network for flower recognition. It is essentially a copy of AlexNet with some layers removed. I trained it on 5 classes, but now I want to add more. How can I do that without retraining it from scratch?

allImages = imageDatastore('D:\stuff machine learning\flowers', 'IncludeSubfolders', true,...
   'LabelSource', 'foldernames');
[trainingImages, testImages] = splitEachLabel(allImages, 0.8, 'randomized');
conv1 = convolution2dLayer(11,96,'Stride',4,'Padding',0); % 290.5k neurons
conv2 = convolution2dLayer(5,256,'Stride',1,'Padding',2); % 7 million neurons
conv3 = convolution2dLayer(3,384,'Stride',1,'Padding',1);
conv4 = convolution2dLayer(3,384,'Stride',1,'Padding',1);
conv5 = convolution2dLayer(3,256,'Stride',1,'Padding',1);
layers = [...
    imageInputLayer([227 227 3]);
    conv1;
    reluLayer('Name','relu1');
    maxPooling2dLayer(3,'Name','pool1','Stride',2);
    conv2;
    reluLayer('Name','relu2');
    maxPooling2dLayer(3,'Name','pool2','Stride',2);
    conv3;
    reluLayer('Name','relu3');
    conv4;
    reluLayer('Name','relu4');
    conv5;
    reluLayer('Name','relu5');
    maxPooling2dLayer(3,'Name','pool5','Stride',2);
    fullyConnectedLayer(4096,'Name','fc6');
    reluLayer('Name','relu6');
    dropoutLayer('Name','drop6');
    fullyConnectedLayer(4096,'Name','fc7');
    reluLayer('Name','relu7');
    dropoutLayer('Name','drop7');
    fullyConnectedLayer(5,'Name','fc8');
    softmaxLayer('Name','prob');
    classificationLayer('Name','output');]
opts = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.1, ...
    'LearnRateDropPeriod', 10, ...
    'L2Regularization', 0.008, ...
    'MaxEpochs', 30, ...
    'MiniBatchSize', 40, ...
    'ValidationData',testImages, ...
    'Verbose', true,...
    'Plots','training-progress');
testImages.ReadFcn = @readFunctionTrain1;
trainingImages.ReadFcn = @readFunctionTrain1;
% train the network
myNet = trainNetwork(trainingImages, layers, opts);
[YPred,probs] = classify(myNet,testImages);
accuracy = mean(YPred == testImages.Labels)
idx = randperm(numel(testImages.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(testImages,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end

This is the network
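
A commonly documented route for this kind of situation is transfer learning: keep the trained layers, swap only the final fully connected layer and the classification layer for ones sized to the new class count, and continue training. The following is a minimal sketch, not a confirmed solution for this network. The helper name `expandClassifier` and the class count 8 are hypothetical, and the sketch assumes a series network whose last three layers are fully connected, softmax, and classification, as in the code above.

```matlab
function newLayers = expandClassifier(layers, numClasses)
% expandClassifier (hypothetical helper): reuse a trained layer array but
% replace the final fully connected and classification layers so the
% network can output numClasses classes.
% Assumes the last three layers are: fully connected, softmax, classification.
newLayers = layers;
% New final fully connected layer with freshly initialised weights.
% The higher learn-rate factors let it adapt faster than the reused layers.
newLayers(end-2) = fullyConnectedLayer(numClasses, 'Name', 'fc8', ...
    'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10);
% New classification layer, so the class names are taken from the new data.
newLayers(end) = classificationLayer('Name', 'output');
end
```

Under these assumptions, a call such as `trainNetwork(trainingImages, expandClassifier(myNet.Layers, 8), opts)` would continue from the learned convolutional weights rather than from random ones. The datastore would need to contain the old classes as well as the new ones, since training only on the new classes is likely to degrade accuracy on the old ones.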

1 Comment
Balakrishnan Rajan on 16 Oct 2018
I am trying to do the same thing. In theory, this should work by enlarging the Weights matrix and Bias vector of the fully connected layer, increasing the OutputSize of both the fully connected layer and the classification output layer, and adding the new category label to the Classes property. However, these properties are read-only.
Peter Gadfort provided a solution in this thread. However, I can't change OutputSize, as it is still a read-only property. If you do find a solution, please post it.
The code I am trying is this:
% Adding new classes to a trained net
%% Create an editable net object
load('BestNet.mat')
TempNet = net.saveobj;
%% Edit the properties of the fully connected layer
FCLayer = TempNet.Layers(142,1);
FCOutputSize = FCLayer.OutputSize;
FCLayer.OutputSize = FCOutputSize + 1;       % this assignment fails: OutputSize is read-only
FCWeights = FCLayer.Weights;
FCWsize = size(FCWeights);
FCLayer.Weights = rand(FCWsize(1)+1, FCWsize(2));
FCLayer.Weights(1:FCWsize(1),:) = FCWeights; % keep the trained rows, the last row stays random
FCBias = FCLayer.Bias;
FCLayer.Bias = rand(numel(FCBias)+1, 1);     % column vector with one extra entry
FCLayer.Bias(1:numel(FCBias)) = FCBias;      % keep the trained biases
%% Edit the properties of the output layer
OutputLayer = TempNet.Layers(144,1);
OLOutputSize = OutputLayer.OutputSize;
OutputLayer.OutputSize = OLOutputSize + 1;
OLClasses = OutputLayer.Classes;
OLClasses(end+1) = 'Obstructed';             % append the new category label
%% Make this the net
net = net.loadobj(TempNet);                  % counterpart of saveobj above (assumed to be available on the network class)
The pretrained net that I am using is a GoogLeNet derivative with the last three layers replaced by a fully connected layer, a softmax layer, and a cross-entropy loss layer. I am adding a new class called "obstructed". Sorted alphabetically, this is the last class, which is why I append the new elements after the existing ones.
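
The weight and bias expansion described in this comment can be illustrated with plain matrices, independent of any read-only layer properties. This is only a sketch: the sizes (5 old classes, 4096 input features) and all variable names are assumptions, and random matrices stand in for the trained parameters.

```matlab
% Stand-ins for the trained parameters of the final fully connected layer
% (hypothetical sizes: 5 old classes, 4096 input features).
numOld = 5;
numFeatures = 4096;
oldWeights = randn(numOld, numFeatures);
oldBias = randn(numOld, 1);

% Append one row of small random weights and one zero bias for the new
% class. The new class is assumed to sort last, so it goes at the end.
numNew = 1;
newWeights = [oldWeights; 0.01 * randn(numNew, numFeatures)];
newBias = [oldBias; zeros(numNew, 1)];

% The scores of the old classes are unchanged for any input x.
x = randn(numFeatures, 1);
oldScores = oldWeights * x + oldBias;
newScores = newWeights * x + newBias;
scoreDiff = max(abs(newScores(1:numOld) - oldScores));
disp(size(newWeights))
```

Rather than editing the saved network struct, the expanded matrices could be assigned to the Weights and Bias properties of a freshly created `fullyConnectedLayer(numOld + numNew)`, which are settable before training. The new row still has to be trained on images of the added class before its score means anything.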


Answers (0)
