% Obtain the pretrained DeepLab v3+ model through the project helper and
% pull out its network object (the starting point for fine-tuning below).
% NOTE(review): the structure of the value returned by the helper is not
% visible here beyond its `net` field — confirm against helper source.
model = helper.downloadPretrainedDeepLabv3Plus();
net = model.net;
% Download and extract the CamVid dataset (raw stills + pixel labels) into
% <tempdir>/CamVid. Every step is skipped when its output already exists, so
% an interrupted earlier run resumes where it stopped instead of fetching
% both archives again.
imageURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip';
labelURL = 'http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip';
outputFolder = fullfile(tempdir,'CamVid');
labelsZip = fullfile(outputFolder,'labels.zip');
imagesZip = fullfile(outputFolder,'images.zip');
labelsOut = fullfile(outputFolder,'labels');
imagesOut = fullfile(outputFolder,'images');

% Create the target folder only when missing (mkdir warns on an existing folder).
if ~exist(outputFolder,'dir')
    mkdir(outputFolder);
end

% Labels archive: download if absent, extract if not yet extracted.
if ~exist(labelsZip,'file')
    disp('Downloading 16 MB CamVid dataset labels...');
    websave(labelsZip, labelURL);
end
if ~exist(labelsOut,'dir')
    unzip(labelsZip, labelsOut);
end

% Images archive: handled independently of the labels so a missing images
% archive does not force the labels to be downloaded again.
if ~exist(imagesZip,'file')
    disp('Downloading 557 MB CamVid dataset images...');
    websave(imagesZip, imageURL);
end
if ~exist(imagesOut,'dir')
    unzip(imagesZip, imagesOut);
end
% Build the image and pixel-label datastores for CamVid and split them into
% training / validation / test subsets via the project helper.
imgDir = fullfile(outputFolder, 'images', '701_StillsRaw_full');
imds = imageDatastore(imgDir);

% The 11 classes to segment, as a column string array (11x1).
classes = ["Sky"; "Building"; "Pole"; "Road"; "Pavement"; "Tree"; ...
    "SignSymbol"; "Fence"; "Car"; "Pedestrian"; "Bicyclist"];

% Per-class label IDs come from the project helper.
% NOTE(review): presumably the RGB colour(s) mapping to each class, in the
% same order as `classes` — confirm against helper source.
labelIDs = helper.camvidPixelLabelIDs();
labelDir = fullfile(outputFolder, 'labels');
pxds = pixelLabelDatastore(labelDir, classes, labelIDs);

% Per-class pixel statistics, later used to derive class weights.
% NOTE(review): computed over the whole dataset, i.e. before the split, so
% test-set labels influence the weights.
tbl = countEachLabel(pxds);

[imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = ...
    helper.partitionCamVidData(imds, pxds);
% Adapt the pretrained network to the CamVid class set:
%  - replace layer 'node_398' with a fresh 1x1 convolution producing one
%    channel per class (NOTE(review): presumably the model's classifier
%    convolution — confirm the layer name against the imported network);
%  - replace the 'labels' output layer with a class-weighted
%    pixelClassificationLayer.
numClasses = numel(classes);

% Median frequency balancing: weight_c = median(freq) / freq_c, where freq_c
% is the class pixel count over the pixel count of images containing class c.
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ imageFreq;

lgraph = layerGraph(net);

convLayer = convolution2dLayer([1 1], numClasses, 'Name', 'node_398');
lgraph = replaceLayer(lgraph, 'node_398', convLayer);

pxLayer = pixelClassificationLayer('Name', 'labels', ...
    'Classes', tbl.Name, ...
    'ClassWeights', classWeights);
lgraph = replaceLayer(lgraph, 'labels', pxLayer);

% Open the interactive Network Analyzer to inspect the modified graph.
analyzeNetwork(lgraph);
% Training-time augmentation: random left-right reflection plus random
% translations in [-10, 10] pixels along x and y.
xTrans = [-10 10];
yTrans = [-10 10];
augmenter = imageDataAugmenter( ...
    'RandXReflection', true, ...
    'RandXTranslation', xTrans, ...
    'RandYTranslation', yTrans);

% Random 513x513 patches, 8 per image; only the training patches are augmented.
% NOTE(review): validation patches are also random crops, so validation
% metrics are not computed on a fixed set of pixels.
patchSize = [513 513];
dsTrain = randomPatchExtractionDatastore(imdsTrain, pxdsTrain, patchSize, ...
    'PatchesPerImage', 8, ...
    'DataAugmentation', augmenter);
dsVal = randomPatchExtractionDatastore(imdsVal, pxdsVal, patchSize, ...
    'PatchesPerImage', 8);
% Training configuration: SGD with momentum, piecewise learning-rate
% schedule, validation on dsVal with ValidationPatience-based early stopping,
% and checkpoints written to tempdir.
% NOTE(review): LearnRateDropPeriod is 10 but MaxEpochs is 6, so the
% piecewise drop never triggers and the learning rate stays at 1e-3 for the
% whole run. Lower the drop period or raise MaxEpochs if a decay is intended.
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 1e-3, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropPeriod', 10, ...
    'LearnRateDropFactor', 0.3, ...
    'Momentum', 0.9, ...
    'L2Regularization', 0.005, ...
    'MaxEpochs', 6, ...
    'MiniBatchSize', 2, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', dsVal, ...
    'ValidationPatience', 4, ...
    'CheckpointPath', tempdir, ...
    'VerboseFrequency', 2, ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'auto');