Cross-validation improvement

Qiang Wang on 22 Sep 2020
Edited: Qiang Wang on 5 Oct 2020
Hi all,
I want to improve my cross-validation results,
so I am following this suggestion from @Greg Heath: use multiple nets to improve the cross-validation results.
a. For each i in 1:k, design multiple nets that differ in their random initial weights. Discard those with poor performance and average the performance of the rest.
Is this code right?
% clear all is omitted: it would delete predictors and response, which must already be in the workspace
close all;
clc
counter=0
for h= 5:20
for j= 5:20
%%input file
xp = predictors;
xt = table2array(predictors);
x= xt' % input
t = response' % target
trainFcn = 'trainbr'; % Bayesian regularization backpropagation.
hiddenLayerSize = [h,j]
net = fitnet(hiddenLayerSize, trainFcn)
sum_rmse =0
fid=fopen('sum-r.txt','w')
KFolds = 10 % number of folds (K)
cvp = cvpartition(size(response, 1), 'KFold', KFolds) % generate the cross-validation partition
for i= 1:KFolds
trainIdx =cvp.training(i); % index of training data
testIdx = cvp.test(i); % index of validation data
trInd=find(trainIdx)
testInd=find(testIdx)
xtrain=x(:,trainIdx);
ytrain=t(:,trainIdx);
xtest=x(:,testIdx);
ytest=t(:,testIdx);
net.layers{1}.transferFcn ='tansig'; % logsig
net.layers{2}.transferFcn ='tansig';
net.layers{3}.transferFcn ='purelin';
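net.divideFcn = 'divideind'; % split by explicit column indices into the full data set
net.divideParam.trainInd=trInd;
net.divideParam.valInd=[]; % no validation subset
net.divideParam.testInd=testInd; % held-out fold
net.trainParam.showWindow = 0
net.performFcn = 'mse' % MSE
%rng('default')
% trInd and testInd index columns of the full data, so pass x and t; only the trainInd columns are used to fit the weights
nets{i} = train(net,x,t) % trains a network net according to net.trainFcn and net.trainParam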
end
%% multiple nets
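for i= 1: KFolds
neti=nets{i}
% rebuild the split of fold i; otherwise every net is scored on the data of the last fold
xtrain=x(:,cvp.training(i));
ytrain=t(:,cvp.training(i));
xtest=x(:,cvp.test(i));
ytest=t(:,cvp.test(i));
yPred = neti(xtest)
ypt=neti(xtrain)
y = neti(x) % estimate the targets for all samples using the trained network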
Train_RMSE = sqrt(sum(( ypt - ytrain ).^2) / numel(ytrain))
%da = y-t;
vrmse = sqrt(sum((yPred-ytest).^2) / numel(ytest))
train_t(i)=Train_RMSE
ttrr(i)=vrmse
R= corrcoef(ytest,yPred)
Rv(i)=R(1,2)
%avrg_rmse = mean(testrmse)
trmse= mse(neti,ytest,yPred)
sum_rmse = sum_rmse +vrmse
counter=counter+1
figname = strcat(num2str(counter),'.mat')
save(figname);
fprintf(fid,'%6d %12.4f\n',i,vrmse); % fold number and its test RMSE
%plotregression(yTest,yPred,'Test')
end
fclose(fid);
average_rmse(h,j)=sum_rmse/KFolds;
%accuracy=mean(tata)
cvrmse(h,j)=mean(ttrr)
ttrmse(h,j)=mean(train_t)
R_value(h,j) = mean(Rv)
end
end
R_value=R_value'
ttrmse= ttrmse'
cvrmse= cvrmse'
save ('test.txt','cvrmse','-ascii')
save('train.txt','ttrmse','-ascii')
save('r.txt','R_value','-ascii')
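
For reference, the quoted suggestion (several nets per fold that differ only in their random initial weights, discard the poor ones, average the rest) could be sketched roughly as follows. This is only an illustration under assumptions: the data are synthetic, the names Ntrials, keepFrac, trainRMSE, testRMSE and foldRMSE are hypothetical, and ranking the nets by their training-fold RMSE is just one possible reading of "poor performance", not necessarily what @Greg Heath meant.

```matlab
% Sketch only: synthetic data; Ntrials, keepFrac and the *RMSE names are assumptions.
rng(0);                                   % reproducible data, partition and initial weights
x = rand(3, 200);                         % 3 predictors, 200 samples (columns = samples)
t = sum(x, 1) + 0.05*randn(1, 200);       % noisy target

KFolds   = 5;                             % number of folds
Ntrials  = 4;                             % nets per fold, each with new random weights
keepFrac = 0.5;                           % fraction of nets kept per fold
cvp = cvpartition(size(x, 2), 'KFold', KFolds);
foldRMSE = zeros(1, KFolds);

for i = 1:KFolds
    trInd  = find(training(cvp, i));      % columns used for fitting
    tstInd = find(test(cvp, i));          % held-out columns
    trainRMSE = zeros(1, Ntrials);
    testRMSE  = zeros(1, Ntrials);
    for n = 1:Ntrials
        net = fitnet(5, 'trainbr');       % fresh net; train() configures it and draws new initial weights
        net.divideFcn = 'divideind';
        net.divideParam.trainInd = trInd;
        net.divideParam.valInd   = [];
        net.divideParam.testInd  = tstInd;
        net.trainParam.showWindow = false;
        net.trainParam.epochs = 50;       % kept small so the sketch runs quickly
        net = train(net, x, t);           % indices refer to the full data set
        trainRMSE(n) = sqrt(mean((net(x(:, trInd))  - t(trInd)).^2));
        testRMSE(n)  = sqrt(mean((net(x(:, tstInd)) - t(tstInd)).^2));
    end
    % Rank by training error, so the held-out fold is not used for selection.
    [~, order] = sort(trainRMSE);
    nKeep = max(1, round(keepFrac*Ntrials));
    foldRMSE(i) = mean(testRMSE(order(1:nKeep)));
end
cvRMSE = mean(foldRMSE)                   % cross-validated RMSE over all folds
```

Selecting the nets by their error on the held-out fold instead would make the averaged test RMSE look better than it really is, which is why the sketch ranks on the training columns only.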

Answers (0)
