Group K-fold partitioning of a dataset

Ivan Abraham on 31 Jul 2018
Answered: Jaimin on 9 Jan 2025
The scikit-learn package in Python has a GroupKFold class that splits a dataset into train/test folds while ensuring that the same "group" is never present in more than one fold. This is useful, for example, in studies where the same subject generates multiple data points and we want to make sure that the samples belonging to one subject don't appear in both the training and the test folds.
I was wondering whether MATLAB has a way to do this, either as an option of the cvpartition function or in some other way. The default options only seem to preserve relative class sizes (stratification).
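The requirement can be stated as a simple check: a split is group-safe if no group ID that occurs in the training set also occurs in the test set. Below is a minimal sketch of that check. `hasGroupLeakage` is a hypothetical anonymous function (not a MATLAB built-in), and the group IDs are made up for illustration.

```matlab
% Hypothetical helper (not built in): true if any group appears in both sets
hasGroupLeakage = @(groups, trainIdx, testIdx) ...
    any(ismember(groups(trainIdx), groups(testIdx)));

% Example: 8 samples from 3 subjects (assumed data)
groups = [1 1 1 2 2 3 3 3]';

% Naive split by sample position: subject 2 ends up on both sides
testNaive = false(8, 1);
testNaive(4) = true;
leakNaive = hasGroupLeakage(groups, ~testNaive, testNaive)        % true

% Group-aware split: all samples of subject 2 go to the test set
testGrouped = (groups == 2);
leakGrouped = hasGroupLeakage(groups, ~testGrouped, testGrouped)  % false
```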

Answers (1)

Jaimin on 9 Jan 2025
While MATLAB does not offer a built-in function exactly like scikit-learn's GroupKFold, you can achieve similar results by manually creating your own group-based cross-validation partitions.
Here is how you can do it:
  1. Determine the unique groups in your dataset.
  2. Randomly shuffle these groups and then split them into k folds.
  3. Assign each data point to a fold based on its group.
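The three steps can also be written without explicit loops. This sketch assumes made-up group IDs and k = 3. It uses the third output of `unique` to map each sample to its group, deals the groups out to folds 1..k in a random order (a shuffle followed by a round-robin split), and then lets every sample inherit its group's fold.

```matlab
% Example group IDs (assumed for illustration): 20 samples from 9 groups
groups = [1 1 1 2 2 3 3 3 3 4 5 5 6 6 6 7 8 8 9 9]';
k = 3;                                   % number of folds

% Step 1: unique groups; groupIdx(j) is the index of sample j's group
[uniqueGroups, ~, groupIdx] = unique(groups);
numGroups = numel(uniqueGroups);

% Step 2: shuffle the groups and deal them out to folds 1..k round-robin
foldOfGroup = zeros(numGroups, 1);
foldOfGroup(randperm(numGroups)) = mod(0:numGroups-1, k) + 1;

% Step 3: every sample inherits the fold of its group
cvIndices = foldOfGroup(groupIdx);
```

The resulting `cvIndices` has the same meaning as in the script below (`cvIndices == i` selects the test samples of fold `i`), so it can be dropped into the same per-fold loop.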
% Sample data
data = rand(100, 5);              % 100 samples, 5 features
labels = randi([0, 1], 100, 1);   % Binary labels
groups = randi([1, 20], 100, 1);  % Group IDs (up to 20 distinct groups)
% Number of folds
k = 5;
% Get unique groups
uniqueGroups = unique(groups);
numGroups = numel(uniqueGroups);
assert(numGroups >= k, 'Need at least k distinct groups.');
% Shuffle groups
shuffledGroups = uniqueGroups(randperm(numGroups));
% Split groups into k folds by dealing them out round-robin
% (every k-th shuffled group goes to the same fold, so no fold is empty)
folds = cell(k, 1);
for i = 1:k
    folds{i} = shuffledGroups(i:k:end);
end
% Create cross-validation partitions: cvIndices(j) is the fold of sample j
cvIndices = zeros(size(groups));
for i = 1:k
    testGroups = folds{i};
    testIdx = ismember(groups, testGroups);
    cvIndices(testIdx) = i;
end
% Use each fold once as the test set and the remaining folds as the training set
for i = 1:k
    testIdx = (cvIndices == i);
    trainIdx = ~testIdx;
    trainData = data(trainIdx, :);
    trainLabels = labels(trainIdx);
    testData = data(testIdx, :);
    testLabels = labels(testIdx);
    % Train a model on trainData/trainLabels and evaluate it on testData/testLabels here
    fprintf('Fold %d: Train on %d samples, Test on %d samples\n', i, sum(trainIdx), sum(testIdx));
end
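Splitting the shuffled groups into k folds balances the number of groups per fold, not the number of samples, so with very unevenly sized groups the test folds can differ a lot in size. As far as I know, scikit-learn's GroupKFold instead balances sample counts greedily: it visits the groups from largest to smallest and puts each one into the fold that currently holds the fewest samples. Below is a sketch of that greedy strategy in MATLAB, assuming made-up group IDs and k = 3. Unlike the script above, it is deterministic (no shuffling).

```matlab
% Example group IDs with uneven group sizes (assumed for illustration)
groups = [1 1 1 1 1 1 2 2 2 3 3 3 3 4 4 5 6 6 6 6 7 8 8]';
k = 3;                                   % number of folds

% Map each sample to a group index and count samples per group
[uniqueGroups, ~, groupIdx] = unique(groups);
groupSizes = accumarray(groupIdx, 1);

% Visit groups from largest to smallest
[~, order] = sort(groupSizes, 'descend');

foldOfGroup = zeros(numel(uniqueGroups), 1);
foldLoad = zeros(k, 1);                  % samples assigned to each fold so far
for g = order'
    [~, lightest] = min(foldLoad);       % fold with the fewest samples
    foldOfGroup(g) = lightest;
    foldLoad(lightest) = foldLoad(lightest) + groupSizes(g);
end

% Fold index (1..k) for every sample, usable as cvIndices in the loop above
cvIndices = foldOfGroup(groupIdx);
disp(foldLoad')                          % test-fold sizes
```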
For more information, refer to the following MathWorks documentation.
