matlabのディー​プラーニングでは、な​ぜテストデータを使わ​ずにバリデーションデ​ータを使うのか

プログラミング初心者です。
下記リンクにつきまして、
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);
という一文がありますが、なぜ、テストデータを使わずにバリデーションデータを使うのでしょうか。
imdsValidationではなく、imdsTestだと納得できるのですが不思議です。
もしバリデーションデータを使うのであれば、テストデータは使わなくてもいいかご教示頂けますと幸いです。

 Akzeptierte Antwort

Kenta
Kenta am 12 Mär. 2019

3 Stimmen

単に、ここではバリデーションデータをテストデータと読み替えて問題ないと思います。また、以下のように、
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.7,0.2);
などとして、画像を訓練、バリデーション、テストデータに分けると良いかもしれません。
リンクの学習曲線のところでは、バリデーションデータを使います。
そして、最後のところで
[YPred,probs] = classify(net,imdsTest);
accuracy = mean(YPred == imdsTest.Labels)
とすると、テストデータで正答率を計算できます。ここで、optionsのところに
'ValidationPatience', 3
を追加すれば学習の早期終了ができます。「'ValidationPatience' の値は、ネットワークの学習が停止するまでに、検証セットでの損失が前の最小損失以上になることが許容される回数です。」
とあります。学習がある程度のところで限界が来たらそこで学習がストップするので学習時間を短縮できたり、過学習が抑えられる可能性があります。

11 Kommentare

ssk
ssk am 12 Mär. 2019
itakuraさま、いつもご回答いただきまして誠にありがとうございます。本例につきまして、バリデーションデータとテストデータを置き換えることができる旨、ご教示いただきありがとうございます。imdstestも使った例でも試してみます。また、'ValidationPatience', も利用してみたいと思います。
あわせて、下記リンクにつきましてitakura様宛に追加でご質問させていただきましたのでご覧頂けますと幸いです。
https://jp.mathworks.com/matlabcentral/answers/447586-cross-validation
ssk
ssk am 13 Mär. 2019
追加でのご質問失礼いたします。
[imdsTrain,imdsTest,imdsValidation] = splitEachLabel(imds,0.7,0.2);と仮にテストデータとバリデーションデータを分けた場合、クロスバリデーションは以前ご教示いただいた方法と同じものを使っても差し支えないでしょうか。それとも、imdstest向けのコードの変更が必要でしょうか。
Kenta
Kenta am 13 Mär. 2019
3つに分けたかったら少し改変する必要があります。
具体的には、9つ分のイメージデータストアを合わせたものをもう一度、imdsTrain, imdsValidationにわける必要があります。
ただ、今回のような結果だといままでどおりの交差検証のコードでひとまずやってみたらよいと思います。枚数を増やしてうまくいかなかったらまた対策を考える方向がよいかと思います。
ssk
ssk am 13 Mär. 2019
ご教示いただきありがとうございます。まずはもともとの交差検証のコードで進めたいと思います。念のため、9つ分のイメージデータストアを合わせたものをもう一度、imdsTrain, imdsValidationにわけるコードにつき、以前ご教示いただいたコードを基に自身で作成してみます。
ssk
ssk am 13 Mär. 2019
トレーニング、テスト、バリデーションの3つに分けたコードを試しに作成してみたのですが、以下のコードでご趣旨を反映できておりますでしょうか。
%% cross validation
[imds01,imds02,imds03,imds04,imds05,imds06,imds07,imds08,imds09,imds010]...
= splitEachLabel(imds,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,'randomize');
imdsTrain1 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files));
imdsTrain1.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels);
imdsTrain2 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds010.Files));
imdsTrain2.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds010.Labels);
imdsTrain3 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds09.Files,imds010.Files));
imdsTrain3.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds09.Labels,imds010.Labels);
imdsTrain4 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain4.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain5 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain5.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain6 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain6.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain7 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds06.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain7.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain8 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain8.Labels = cat(1,imds01.Labels,imds02.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain9 = imageDatastore(cat(1,imds01.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain9.Labels = cat(1,imds01.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain10 = imageDatastore(cat(1,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain10.Labels = cat(1,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
%% training
accuracy=zeros(1,10);
for i=1:10
stname1=sprintf('imdsTrain%d',i);
eval(['trainimds' ,'=', stname1,';'])
%trainimds.ReadFcn = @(filename)resize(filename);
i2=10-i+1;
stname2=sprintf('imds0%d',i2);
eval(['imdsValidation' ,'=', stname2,';'])
imdsValidation.ReadFcn = @(filename)resize(filename);
[imds11,imds12,imds13,imds14,imds15]...
= splitEachLabel(imds,0.2,0.2,0.2,0.2,'randomize');
imdsTest11 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds13.Files,imds14.Files));
imdsTest11.Labels = cat(1,imds11.Labels,imds12.Labels,imds13.Labels,imds14.Labels);
imdsTest12 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds13.Files,imds15.Files));
imdsTest12.Labels = cat(1,imds11.Labels,imds12.Labels,imds13.Labels,imds15.Labels);
imdsTest13 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds14.Files,imds15.Files));
imdsTest13.Labels = cat(1,imds11.Labels,imds12.Labels,imds14.Labels,imds15.Labels);
imdsTest14 = imageDatastore(cat(1,imds11.Files,imds13.Files,imds14.Files,imds15.Files));
imdsTest14.Labels = cat(1,imds11.Labels,imds13.Labels,imds14.Labels,imds15.Labels);
imdsTest15 = imageDatastore(cat(1,imds12.Files,imds13.Files,imds14.Files,imds15.Files));
imdsTest15.Labels = cat(1,imds11.Labels,imds13.Labels,imds14.Labels,imds15.Labels);
%% training for test data
accuracy=zeros(11,15);
for i3=11:15
stname3=sprintf('imdsTest%d',i3);
eval(['imdsTest' ,'=', stname3,';'])
%imdsTest.ReadFcn = @(filename)resize(filename);
i4=15-i+1;
stname4=sprintf('imds0%d',i4);
eval(['imdsValidation' ,'=', stname4,';'])
imdsValidation.ReadFcn = @(filename)resize(filename);
%%train network(中略)
[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
[YPred,probs] = classify(net,imdsTest);
accuracy = mean(YPred == imdsTest.Labels)
Kenta
Kenta am 14 Mär. 2019
i番目のループのなかで、トレーニングデータ(仮)をトレーニングデータとバリデーションデータに分けたらいいと思います。そして、バリデーションデータをテストデータ(ただ名前を変えるだけ)としてテストしたらいいです。
ある程度までロスが下がり切ったりしたら計算時間が冗長になるし、訓練データに過適合するのを防げます。ただ、たくさんの枚数をこなしたときに必ずしももこの操作が必要かどうかは不明です。1クラス100枚くらいで交差検証なしでやってみてはどうでしょうか。CPUで計算してもそこまで計算時間はかからないと思います。
%% cross validation
[imds01,imds02,imds03,imds04,imds05,imds06,imds07,imds08,imds09,imds010]...
= splitEachLabel(imds,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,'randomize');
imdsTrain1 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files));
imdsTrain1.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels);
imdsTrain2 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds010.Files));
imdsTrain2.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds010.Labels);
imdsTrain3 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds09.Files,imds010.Files));
imdsTrain3.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds09.Labels,imds010.Labels);
imdsTrain4 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain4.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain5 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain5.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain6 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain6.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain7 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds06.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain7.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain8 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain8.Labels = cat(1,imds01.Labels,imds02.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain9 = imageDatastore(cat(1,imds01.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain9.Labels = cat(1,imds01.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain10 = imageDatastore(cat(1,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain10.Labels = cat(1,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
%% training
accuracy=zeros(1,10);
for i=1:10
stname1=sprintf('imdsTrain%d',i);
eval(['trainimds' ,'=', stname1,';'])
%trainimds.ReadFcn = @(filename)resize(filename);
[imdstrain,imdsvalidation]=splitEachLabel(trainimds,0.8);
i2=10-i+1;
stname2=sprintf('imds0%d',i2);
eval(['imdsTest' ,'=', stname2,';'])
imdsTest.ReadFcn = @(filename)resize(filename);
%% training for test data
%imdstrainで訓練
%imdsvalidationをoptionsのなかのvalidationに指定
%imdstestでテスト
ssk
ssk am 14 Mär. 2019
Bearbeitet: ssk am 14 Mär. 2019
ありがとうございます!コードを試したところ無事に動きました。本コードにおけるクロスバリデーションのニュアンスの確認をしたいのですが、はじめに全ての画像をtrainingとして均等に10分割し、さらに10分割した画像をそれぞれtraining:validation = 8:2で分ける。このとき、testはvalidationと同視できるので、training:test = 8:2である。(つまり、本データの8割をtraining、2割をtest(validation)として使う。その後、組み合わせをかえてそれぞれの画像のaccuracyを調べて平均を取る。上記の認識でよろしいでしょうか?
以前あった例ですと、
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.7,0.2,0.1); で合計が100%ですが、今回の場合は、[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.8,0.2,0.2);で合計120%のような気もするのですが、例えば[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.6,0.2,0.2);のような形で修正する必要はないのでしょうか?
また、なぜテストデータとバリデーションデータを同視できるか理由をご存知でしたらご教示いただけますと幸いです。
Kenta
Kenta am 15 Mär. 2019
今回の場合は、1000枚あったとすると、900, 100に分けて、その900をさらに800と100に分けたイメージです。
もちろんおっしゃるように、1000枚を一気に8:1:1に分けても等価です。そのサイクルを10回して、それを平均すればいいです。
同視できるというのは、バリデーションデータと名付けられたデータでテストをしているので、その場合に限り、バリデーションデータをテストデータと読み替えて問題ないのでは?ということです。
本来、バリデーションデータとテストデータは異なったニュアンスを持っているものと思います。
ssk
ssk am 15 Mär. 2019
ニュアンスをご教示いただきありがとうございました!
覚えが悪く大変申し訳なのでもう一度確認しますが、(1,000枚の画像がある場合、)まず900(training), 100(test)に分けて、その後で900を更に800(training)、100(varidation)に分けるということですね。以上から、800(training)、100(test)、100(validation)になるということでしょうか。
上記コードですと、for loop内では以下のようになっておりますので
[imdstrain,imdsvalidation]=splitEachLabel(trainimds,0.8);
imstrain1の場合、
training:900(枚)*0.8= 720(枚)
validation:900(枚)*0.2= 180(枚)
test: 100(枚)・・・imds10
これを順に10回続けていって・・・
imstrain10の場合、
training:900(枚)*0.8= 720(枚)
validation:900(枚)*0.2= 180(枚)
test: 100(枚)・・・imds01
以上の平均を求めるという認識でよろしいでしょうか?
Kenta
Kenta am 17 Mär. 2019
はい、それで正しいと思います。
ssk
ssk am 17 Mär. 2019
ありがとうございます!

Melden Sie sich an, um zu kommentieren.

Weitere Antworten (0)

Kategorien

Mehr zu Deep Learning Toolbox finden Sie in Hilfe-Center und File Exchange

Gefragt:

ssk
am 7 Mär. 2019

Kommentiert:

ssk
am 17 Mär. 2019

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!