Filter löschen
Filter löschen

How to change output classes for image classification in alexnet?

2 Ansichten (letzte 30 Tage)
Hello everyone
I'm trying to implement the following example:
but I got this error:
Error using trainNetwork (line 170)
Invalid training data. The output size (28224) of the last layer does not match the number of classes (24).
and this is my code:
tic
clear;
clc;
outputFolder = fullfile('Spilt_Dataset');
rootFolder = fullfile(outputFolder , 'Categories');
categories = {'front_or_left' , 'front_or_right','hump' , 'left_turn','narrow_from_left' , 'narrows_from_right',...
'no_horron' , 'no_parking','no_u_turn' , 'overtaking_is_forbidden','parking' , 'pedestrian_crossing',...
'right_or_left' , 'right_turn','rotor' , 'slow','speed_30' , 'speed_40',...
'speed_50','speed_60','speed_80','speed_100','stop','u_turn'};
imds = imageDatastore(fullfile(rootFolder,categories),'LabelSource','Foldernames');
tbl = countEachLabel(imds);
minSetCount = min(tbl{:,2});
imds = splitEachLabel(imds, minSetCount, 'randomize');
countEachLabel(imds);
front_or_left = find(imds.Labels == 'front_or_left',1);
front_or_right = find(imds.Labels == 'front_or_right',1);
hump = find(imds.Labels == 'hump',1);
left_turn = find(imds.Labels == 'left_turn',1);
narrow_from_left = find(imds.Labels == 'narrow_from_left',1);
narrows_from_right = find(imds.Labels == 'narrows_from_right',1);
no_horron = find(imds.Labels == 'no_horron',1);
no_parking = find(imds.Labels == 'no_parking',1);
no_u_turn = find(imds.Labels == 'no_u_turn',1);
overtaking_is_forbidden = find(imds.Labels == 'overtaking_is_forbidden',1);
parking = find(imds.Labels == 'parking',1);
pedestrian_crossing = find(imds.Labels == 'pedestrian_crossing',1);
right_or_left = find(imds.Labels == 'right_or_left',1);
right_turn = find(imds.Labels == 'right_turn',1);
rotor = find(imds.Labels == 'rotor',1);
slow = find(imds.Labels == 'slow',1);
speed_30 = find(imds.Labels == 'speed_30',1);
speed_40 = find(imds.Labels == 'speed_40',1);
speed_50 = find(imds.Labels == 'speed_50',1);
speed_60 = find(imds.Labels == 'speed_60',1);
speed_80 = find(imds.Labels == 'speed_80',1);
speed_100 = find(imds.Labels == 'speed_100',1);
stop = find(imds.Labels == 'stop',1);
u_turn = find(imds.Labels == 'u_turn',1);
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
net = googlenet;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
numClasses = numel(categories(imdsTrain.Labels))
newLearnableLayer = fullyConnectedLayer(numClasses, ...
'Name','new_fc', ...
'WeightLearnRateFactor',10, ...
'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'loss3-classifier',newLearnableLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
'RandXReflection',true, ...
'RandXTranslation',pixelRange, ...
'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
netTransfer = trainNetwork(augimdsTrain,lgraph,options);

Akzeptierte Antwort

Srivardhan Gadila
Srivardhan Gadila am 16 Apr. 2021
If the classes listed from line "front_or_left = find(imd...." in the above code are the only classes which are 24 in total then the issue might be that the value returned by
numClasses = numel(categories(imdsTrain.Labels))
is probably 28224 and not 24. If that's the case then specify the value 24 directly in the line
newLearnableLayer = fullyConnectedLayer(24, ...
and try running the code.
You can try debugging the issue by checking the layer activation size by using analyzeNetwork function.
If the network is fine then the issue might be w.r.t your data i.e., labels itself.
If you still face the issue then Contact MathWorks Support.

Weitere Antworten (0)

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by