5 fold cross validation code for a dataset
I want to split my dataset into a training set and a test set using 5-fold cross-validation.
Answers (1)
Shubham
on 4 Sep 2024
Hi Subhasmita,
In MATLAB, you can perform k-fold cross-validation to split your dataset into training and test sets. In k-fold cross-validation, the dataset is divided into k subsets (folds). The model is trained k times, each time using a different fold as the test set and the remaining folds as the training set.
Here's how you can perform 5-fold cross-validation in MATLAB:
% Load your dataset
load fisheriris % Example dataset
X = meas; % Features
y = species; % Labels
% Define the number of folds
k = 5;
% Create a cross-validation partition
cv = cvpartition(y, 'KFold', k);
% Initialize variable to store accuracy for each fold
accuracy = zeros(k, 1);
for i = 1:k
    % Get the training and test indices for the current fold
    trainIdx = training(cv, i);
    testIdx = test(cv, i);
    % Split the data into training and test sets for this fold
    XTrain = X(trainIdx, :);
    yTrain = y(trainIdx);
    XTest = X(testIdx, :);
    yTest = y(testIdx);
    % Train the model on the training set. fitcsvm supports only one-class
    % and binary problems, so the three-class iris data needs fitcecoc,
    % which combines binary SVM learners by default.
    model = fitcecoc(XTrain, yTrain);
    % Test the model on the test set
    predictions = predict(model, XTest);
    % Calculate accuracy for the current fold. The labels are a cell array
    % of character vectors, so compare them with strcmp rather than ==.
    accuracy(i) = mean(strcmp(predictions, yTest));
    % Display accuracy for the current fold
    fprintf('Fold %d Accuracy: %.2f%%\n', i, accuracy(i) * 100);
end
% Calculate and display the average accuracy across all folds
averageAccuracy = mean(accuracy);
fprintf('Average Accuracy: %.2f%%\n', averageAccuracy * 100);
Explanation
- Data Loading: We use the fisheriris dataset for demonstration, where X contains the features and y contains the labels.
- Cross-Validation Partition: We create a 5-fold partition using cvpartition with the option 'KFold', k. Because the labels y are passed in, the folds are stratified by class.
- Loop Through Folds: For each fold, we:
  - Extract training and test indices.
  - Split the data into training and test sets.
  - Train a support vector machine (SVM) classifier. Note that fitcsvm supports only one-class and binary problems, so the three-class fisheriris data needs a multiclass wrapper such as fitcecoc, which uses SVM binary learners by default.
  - Predict on the test set and calculate accuracy.
- Accuracy Calculation: We calculate and print the accuracy for each fold and the average accuracy across all folds.
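If you want to check what the partition step produces before training anything, a minimal sketch like the following may help. It assumes the same fisheriris data as above; the variable names classes and counts are only illustrative, and rng(1) is added purely so the split is repeatable.

```matlab
% Inspect a stratified 5-fold partition of the fisheriris labels
load fisheriris            % provides meas (150x4) and species (150x1 cell)
rng(1);                    % fix the random seed so the partition is repeatable
cv = cvpartition(species, 'KFold', 5);

disp(cv.TrainSize);        % number of training observations in each fold
disp(cv.TestSize);         % number of test observations in each fold

% Count how many observations of each class land in each test fold
classes = unique(species);
counts = zeros(cv.NumTestSets, numel(classes));
for i = 1:cv.NumTestSets
    testLabels = species(test(cv, i));
    for c = 1:numel(classes)
        counts(i, c) = sum(strcmp(testLabels, classes{c}));
    end
end
disp(array2table(counts, 'VariableNames', classes'));
```

Each observation appears in exactly one test fold, so every column of counts sums to that class's total. With stratification, the per-fold counts should also be close to equal across folds.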
Additional Notes
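- You can swap the SVM for any other classifier that suits your needs.
- Check the class balance of your dataset before cross-validating: if a class is rare, some folds may contain very few examples of it. Passing the labels to cvpartition, as above, stratifies the folds by class.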
- You might also want to explore MATLAB's crossval function, which automates some parts of this process.
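As a rough sketch of that shorter route, the example below trains one multiclass model on the full data and lets crossval build the 5-fold partitioned model; kfoldLoss then returns the misclassification rate averaged over the folds. It assumes the fisheriris data and a fitcecoc model with default SVM learners, so adjust it to your own classifier.

```matlab
% 5-fold cross-validation using the crossval method of a trained model
load fisheriris                         % provides meas and species
rng(1);                                 % repeatable fold assignment

model = fitcecoc(meas, species);        % multiclass model with SVM binary learners
cvModel = crossval(model, 'KFold', 5);  % retrains the model on each of the 5 folds

loss = kfoldLoss(cvModel);              % classification error averaged over the folds
fprintf('Cross-validated accuracy: %.2f%%\n', (1 - loss) * 100);
```

This gives a single averaged figure rather than per-fold output. If you need the individual fold results, the explicit loop above is easier to adapt.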