Cross-validation improvement
Hi all,
I want to improve my cross-validation results, so I am following this suggestion from @Greg Heath: use multiple nets to improve the cross-validation results.
a. For each i of i = 1:k, design multiple nets differing by the assignment of random initial weights. Discard those with poor performance and average the performance of the rest.
Is this code right?
clear all;
close all;
clc;
counter = 0;
for h = 5:20
for j = 5:20
%% input data
xt = table2array(predictors);
x = xt';               % inputs  (features x samples)
t = response';         % targets (1 x samples)
trainFcn = 'trainbr';  % Bayesian regularization backpropagation
hiddenLayerSize = [h, j];
net = fitnet(hiddenLayerSize, trainFcn);
sum_rmse = 0;
fid = fopen('sum-r.txt', 'w');
KFolds = 10;                                            % K value
cvp = cvpartition(size(response, 1), 'KFold', KFolds);  % generate cv partition
for i = 1:KFolds
trainIdx = cvp.training(i);  % logical index of training data
testIdx = cvp.test(i);       % logical index of test data
xtrain = x(:, trainIdx);
ytrain = t(:, trainIdx);
xtest = x(:, testIdx);
ytest = t(:, testIdx);
net.layers{1}.transferFcn = 'tansig';  % or logsig
net.layers{2}.transferFcn = 'tansig';
net.layers{3}.transferFcn = 'purelin';
% Use all of xtrain for training; the test fold is already held out above.
% (The original 'divideind' call referenced an undefined tstInd, and its
% indices pointed into the full dataset while train() only received the
% training subset.)
net.divideFcn = 'dividetrain';
net.trainParam.showWindow = 0;
net.performFcn = 'mse';
%rng('default')
nets{i} = train(net, xtrain, ytrain);  % train one network per fold
end
%% evaluate each fold's net on its own held-out fold
% (In the original code, xtest/ytest kept the values from the last fold of
% the training loop, so every net was scored on fold K's data.)
for i = 1:KFolds
neti = nets{i};
trainIdx = cvp.training(i);
testIdx = cvp.test(i);
xtrain = x(:, trainIdx);
ytrain = t(:, trainIdx);
xtest = x(:, testIdx);
ytest = t(:, testIdx);
ypt = neti(xtrain);    % predictions on the training fold
yPred = neti(xtest);   % predictions on the held-out fold
Train_RMSE = sqrt(sum((ypt - ytrain).^2) / numel(ytrain));
vrmse = sqrt(sum((yPred - ytest).^2) / numel(ytest));
train_t(i) = Train_RMSE;
ttrr(i) = vrmse;
R = corrcoef(ytest, yPred);
Rv(i) = R(1, 2);
sum_rmse = sum_rmse + vrmse;
counter = counter + 1;
save(strcat(num2str(counter), '.mat'));  % snapshot of the workspace
%plotregression(ytest, yPred, 'Test')
end
fprintf(fid, '%6.2f\n\n', sum_rmse / KFolds);
fclose(fid);
average_rmse(h, j) = sum_rmse / KFolds;
cvrmse(h, j) = mean(ttrr);
ttrmse(h, j) = mean(train_t);
R_value(h, j) = mean(Rv);
end
end
R_value = R_value';
ttrmse = ttrmse';
cvrmse = cvrmse';
save('test.txt', 'cvrmse', '-ascii');
save('train.txt', 'ttrmse', '-ascii');
save('r.txt', 'R_value', '-ascii');
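One thing to note: the code above trains a single net per fold, while step (a) of the suggestion calls for several nets per fold, discarding the poor ones and averaging the rest. A minimal sketch of that step, assuming `x`, `t`, `cvp`, `KFolds`, and a configured `net` as above (`nTrials` and the median-based discard rule are illustrative choices, not part of the original suggestion):

```matlab
nTrials = 5;                     % nets per fold (illustrative choice)
cv_rmse = zeros(1, KFolds);
for i = 1:KFolds
    xtr = x(:, cvp.training(i));  ttr = t(:, cvp.training(i));
    xte = x(:, cvp.test(i));      tte = t(:, cvp.test(i));
    rmse_trials = zeros(1, nTrials);
    for k = 1:nTrials
        rng(k);                              % different random initial weights
        netk = configure(net, xtr, ttr);     % re-initialize the network
        netk = train(netk, xtr, ttr);
        yk = netk(xte);
        rmse_trials(k) = sqrt(mean((yk - tte).^2));
    end
    keep = rmse_trials <= median(rmse_trials);  % discard the poorer half
    cv_rmse(i) = mean(rmse_trials(keep));       % average the rest
end
fprintf('mean CV RMSE = %.4f\n', mean(cv_rmse));
```

This could be dropped inside the `h`/`j` loops in place of the single `train` call, with `cv_rmse` feeding into `cvrmse(h, j)`.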