How to get confusion matrix data from the patternnet function with cross-validation?

Is the following program correct? I am trying to do 10-fold cross-validation:
feat = sig_feat(:,1:9);     % features
labels = sig_feat(:,10)     % labels
fold = cvpartition(labels,'kfold',10);
r=1;
net = patternnet(6);
while r<=10
    trainIdx=fold.training(r); testIdx=fold.test(r);
    xtrain=feat(trainIdx,:); ytrain=labels(trainIdx);
    xtest = feat(testIdx,:); ytest = labels(testIdx);
    train_data=[ytrain xtrain];
    test_data = [ytest xtest];
    net = train(net,xtrain',ytrain');
    %view(net)
    y_pred = net(xtest');
    perf = perform(net,ytest,y_pred');
    %%
    [c,cm,ind,per] = confusion(ytest',y_pred);
    confmat{r}=cm;
    r=r+1;
end

Answers (1)

Sai Pavan on 15 Apr 2024
Hello,
I understand that you want to know whether your approach to getting a confusion matrix from a pattern recognition network ("patternnet") with cross-validation is correct. The overall structure, 10-fold cross-validation with "cvpartition" and one confusion matrix per fold, is sound, but a few changes are needed for the code to run and to give valid results:
1. Shallow-network functions such as "train" and "patternnet" expect the inputs 'x' as a matrix with one variable per row and one observation per column (R-by-Q), which is why the features are transposed with xtrain'. For classification, the targets 'y' must be an S-by-Q matrix of one-hot columns, where S is the number of classes. Your 'ytrain' and 'ytest' are vectors of class labels, so they need to be converted to this one-hot form.
2. The "confusion" function expects the same format: targets and outputs as S-by-Q matrices with one column per sample (one-hot targets, network outputs between 0 and 1). Passing the label vector ytest' as the targets does not fit this format for a multi-class problem.
3. Create the network inside the loop so that each fold starts from freshly initialized weights. If a single network is created before the loop, "train" continues from the weights learned in the previous fold, whose training set contained the current test fold, so test data leaks into training.
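To make points 1 and 2 concrete, here is a small sketch of the data layout these functions expect. The 5-row matrix standing in for sig_feat is hypothetical. It uses "ind2vec" and "vec2ind" (Deep Learning Toolbox) to convert between class indices and one-hot columns; for positive integer labels, dummyvar(labels)' produces the same target matrix.

```matlab
% Hypothetical stand-in for sig_feat: 5 observations, 9 features, labels in column 10
sig_feat = [reshape(1:45, 5, 9), [1; 3; 2; 3; 1]];
feat   = sig_feat(:, 1:9);    % 5-by-9: observations in rows
labels = sig_feat(:, 10);     % 5-by-1 class indices

% train/patternnet expect one observation per column
x = feat';                    % 9-by-5 (R-by-Q)

% One-hot targets: ind2vec takes a row vector of indices and returns a sparse S-by-Q matrix
t = full(ind2vec(labels'));   % 3-by-5 (S-by-Q)

% vec2ind maps one-hot columns (or network outputs) back to class indices
recovered = vec2ind(t);
```

Applying vec2ind to the network output y_pred in the same way gives the predicted class of each test sample.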
The code below applies these modifications:
feat = sig_feat(:,1:9);   % Features: one observation per row
labels = sig_feat(:,10);  % Class labels (dummyvar needs positive integers)
confmat = cell(1,10);     % One confusion matrix per fold
% Encode all labels once so that every fold gets the same number of target rows
targets = dummyvar(labels)';   % S-by-Q one-hot matrix, one column per observation
fold = cvpartition(labels,'KFold',10);   % stratified by class
for r = 1:10
    trainIdx = fold.training(r);
    testIdx = fold.test(r);
    xtrain = feat(trainIdx,:);
    xtest = feat(testIdx,:);
    ttrain = targets(:,trainIdx);
    ttest = targets(:,testIdx);
    % Create a fresh network in each fold so no trained weights carry over
    net = patternnet(6);
    [net, tr] = train(net, xtrain', ttrain);   % inputs must be R-by-Q, hence the transpose
    y_pred = net(xtest');                       % S-by-Q network outputs
    perf = perform(net, ttest, y_pred);
    % cm(i,j) = number of test samples of class i classified as class j
    [c,cm,ind,per] = confusion(ttest, y_pred);
    confmat{r} = cm;
end
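Since the goal is confusion-matrix data for the whole cross-validation, it can help to combine the per-fold matrices. Every observation lands in exactly one test fold, so summing the ten matrices gives the confusion matrix of all out-of-fold predictions. A minimal sketch, with made-up 3-class matrices standing in for confmat:

```matlab
% Hypothetical per-fold confusion matrices standing in for confmat{1..10}
% (rows = true class, columns = predicted class, as returned by confusion)
confmat = {[3 1 0; 0 4 0; 0 1 2], ...
           [4 0 0; 1 2 1; 0 0 3]};

% Stack the matrices along the 3rd dimension and sum over folds
totalCM = sum(cat(3, confmat{:}), 3);

% Overall cross-validated accuracy: correct predictions / all predictions
cvAccuracy = trace(totalCM) / sum(totalCM(:));

% Per-class recall: correctly classified samples of class i / all samples of class i
recall = diag(totalCM) ./ sum(totalCM, 2);
```

Alternatively, you could store vec2ind(y_pred) and the true class indices from each fold and pass the concatenated vectors to "confusionmat". When every class occurs in the data, this should produce the same matrix as totalCM.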
For more on the "dummyvar" function, see its documentation: https://www.mathworks.com/help/stats/dummyvar.html
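A caveat about calling "dummyvar" on each fold separately: for numeric labels it requires positive integers and creates one column per value 1:max(g). If the highest-numbered class is missing from a fold, that fold's target matrix has fewer rows than the network output, and "perform" and "confusion" fail with a size mismatch. Because cvpartition(labels,'KFold',10) stratifies by class, this mainly happens for classes with fewer than 10 observations. The sketch below uses hypothetical labels to show the mismatch and one fold-independent alternative, a comparison against the full list of classes, which also works for labels such as 0/1:

```matlab
% Hypothetical labels for one test fold: class 3 exists in the full data set
% but happens to be missing from this fold
yfold = [1; 2; 2; 1];
Dfold = dummyvar(yfold)';          % only 2 rows, because dummyvar uses 1:max(yfold)

% Fold-independent encoding: fix the class list from the full label vector
allLabels = [1; 2; 3; 2; 1; 3];    % hypothetical full label column
classes = unique(allLabels);       % [1; 2; 3]
Tfold = double(yfold' == classes); % 3-by-4 one-hot via implicit expansion (R2016b or later)
```

Encoding the full label vector once and selecting columns per fold with the logical index from "cvpartition" avoids the problem in the same way.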
Hope it helps!

Categories

Pattern Recognition and Classification

Version

R2019b
