Shapley values for SVM classification with 10-fold cross-validation

Hi,
I would like to compute Shapley values for an SVM that is evaluated with 10-fold cross-validation. Here is my code:
load fisheriris
inds = ~strcmp(species,'setosa');
X = meas(inds,3:4);
y = species(inds);
SVMModel = fitcsvm(X,y,Kfold=10);
explainer = shapley(SVMModel);
Error using shapley
Blackbox model must be a classification model, regression model, or function handle
How can I solve this?

Answers (1)

the cyclist on 25 Feb 2023
Edited: the cyclist on 25 Feb 2023
Full disclosure: I am a bit new to this myself.
I believe the reason your code does not work is that, because of the 10-fold cross-validation, your syntax is technically creating 10 different SVM models (fitcsvm returns a cross-validated, partitioned model object rather than a single classifier). shapley is expecting a single blackbox model.
I did not see a way to take the output of fitcsvm (with cross-validation) and get a single model that shapley() will accept, although I think that may be possible. An alternative is to create the individual folds with the cvpartition function and then train each of the SVM models "manually". You can then run shapley on each of the individual models and average the Shapley values across the folds.
Here is some code, created by ChatGPT, that gives the idea. You'll need to swap in your actual data, of course.
I realize this is a pain, and there may be another way. (But I also realize that questions like yours often don't get too many answers here, so I wanted to share what I know.)
% Load your dataset and split it into k subsets
k = 10;
cv = cvpartition(numObservations,'KFold',k);
% Train an SVM model on each fold and compute Shapley values
for i = 1:k
    % Get the training and test data for this fold
    trainData = data(cv.training(i),:);
    testData = data(cv.test(i),:);
    trainLabels = labels(cv.training(i));
    % Train an SVM model on the training data
    model = fitcsvm(trainData,trainLabels);
    % Compute the Shapley values for this model
    shapValues{i} = shapley(model,testData); % I don't think this is quite right for getting the model shapley values. Need to check this.
end
% Average the Shapley values across folds
meanShapValues = mean(cat(3,shapValues{:}),3);
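On the open question of whether the output of a cross-validated fitcsvm call can be reused: the cross-validated object keeps the per-fold models in its Trained property and the fold assignment in its Partition property, so one possible route is to hand each stored model to shapley directly. The following is only a sketch under some assumptions: the names CVModel, foldModel and explainers are made up for illustration, the stored models are compact (so the predictor data is passed to shapley explicitly), and using each fold's training set as that data is a choice made here, not something from the question.

```matlab
% Sketch: reuse the per-fold models stored in the cross-validated object
load fisheriris
inds = ~strcmp(species,'setosa');
X = meas(inds,3:4);
y = species(inds);

rng(1)                                % reproducible fold assignment
CVModel = fitcsvm(X,y,'KFold',10);    % cross-validated (partitioned) model

numFolds = CVModel.KFold;
explainers = cell(numFolds,1);
for i = 1:numFolds
    foldModel = CVModel.Trained{i};             % compact SVM for fold i
    trainIdx  = training(CVModel.Partition,i);  % rows used to train it
    testIdx   = test(CVModel.Partition,i);      % rows held out for it
    Xtest     = X(testIdx,:);
    % Compact models do not store training data, so X is passed explicitly
    % (assumption: the fold's training rows serve as the reference data)
    explainers{i} = shapley(foldModel,X(trainIdx,:));
    % Explain the first held-out observation of this fold
    explainers{i} = fit(explainers{i},Xtest(1,:));
end
disp(explainers{1}.ShapleyValues)
```

This avoids retraining the models by hand, at the cost of depending on how the cross-validated object stores its folds.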

3 Comments

That code from ChatGPT was giving errors when I tried to put in some simple data. I think it also has some syntactical "misunderstandings". Here is some code, using random data, that works:
% Made-up data
numObservations = 1000;
x1 = randn(numObservations,1);
x2 = randn(numObservations,1);
labels = string(binornd(1,0.2,numObservations,1));
% Put the explanatory data into a table
data = table(x1,x2);
% Split the data into k subsets
numberOfFolds = 10;
cv = cvpartition(numObservations,'KFold',numberOfFolds);
% Train an SVM model on each fold and compute Shapley values
for i = 1:numberOfFolds
    % Get the training and test data for this fold
    trainData = data(cv.training(i),:);
    testData = data(cv.test(i),:);
    trainLabels = labels(cv.training(i));
    % Train an SVM model on the training data, for this fold
    modelThisFold = fitcsvm(trainData,trainLabels);
    % Create the Shapley object for this fold (on the test data)
    shapleyObject{i} = shapley(modelThisFold,testData);
    % Extract the Shapley values for the first test observation, for this fold
    explainerFirstDatapoint{i} = fit(shapleyObject{i},testData(1,:));
    shapleyValuesFirstDatapoint{i} = explainerFirstDatapoint{i}.ShapleyValues;
end
The only way I found to get the Shapley values out was as a table for each fold. Averaging these tables across folds is a bit awkward, but the data are there.
Also note that this only calculates the Shapley values for the first observation of each fold's test data.
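One way to average such per-fold tables is to pull out their numeric columns, stack them along a third dimension, and take the mean. The sketch below assumes the table layout seen in R2022b (a Predictor column followed by one numeric column per class, with predictors in the same order in every fold). The two tables and the name foldTables are hypothetical stand-ins with made-up numbers, so the example runs without training anything.

```matlab
% Hypothetical per-fold tables mimicking the ShapleyValues layout
% (Predictor column, then one numeric column per class); values are made up
foldTables = { ...
    table(["x1";"x2"],[0.10;-0.20],[-0.10;0.20], ...
          'VariableNames',{'Predictor','0','1'}), ...
    table(["x1";"x2"],[0.30;-0.40],[-0.30;0.40], ...
          'VariableNames',{'Predictor','0','1'})};

% Extract the numeric part of each table and stack along dimension 3
numericParts = cellfun(@(t) t{:,2:end},foldTables,'UniformOutput',false);
meanValues = mean(cat(3,numericParts{:}),3);

% Rebuild a table with the original predictor and class names
classNames = foldTables{1}.Properties.VariableNames(2:end);
meanShapley = array2table(meanValues,'VariableNames',classNames);
meanShapley = addvars(meanShapley,foldTables{1}.Predictor, ...
    'Before',1,'NewVariableNames','Predictor');
disp(meanShapley)
```

With the loop above, shapleyValuesFirstDatapoint could be passed in place of foldTables. Because each fold's first test observation is a different data point, the result is a summary across points rather than an explanation of any single one.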
Thank you so much!!
With your code I am able to evaluate Shapley values for the whole test data with a simple cycle.
You made my day.
the cyclist on 25 Feb 2023
Edited: the cyclist on 25 Feb 2023
I'm really happy to have helped. I do feel that what I wrote (aided by ChatGPT) is still pretty awkward. If you found a more elegant or efficient approach, please share it here.
I assume that by "cycle" you mean you used for loops, which I guess are needed over both the folds and the data points.
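For reference, a nested loop of that kind might look like the sketch below. It is not the code the original poster used, which was not shared. The data are made up, the variable names (shapleyLastClass, meanAbsShapley) are hypothetical, and the sketch assumes each ShapleyValues table holds one row per predictor with the last column belonging to the last class. Storing only that column, and summarizing with the mean absolute value, are choices made here for illustration.

```matlab
% Made-up data: two predictors, binary string labels
rng(0)
numObservations = 200;
x1 = randn(numObservations,1);
x2 = randn(numObservations,1);
data = table(x1,x2);
labels = string(double(x1 + 0.5*randn(numObservations,1) > 0));

numberOfFolds = 5;
cv = cvpartition(numObservations,'KFold',numberOfFolds);

% One row per observation, one column per predictor
shapleyLastClass = nan(numObservations,width(data));

for i = 1:numberOfFolds                   % loop over folds
    trainIdx = training(cv,i);
    testIdx  = find(test(cv,i));
    model = fitcsvm(data(trainIdx,:),labels(trainIdx));
    explainer = shapley(model,data(trainIdx,:));
    for j = 1:numel(testIdx)              % loop over this fold's test points
        explainer = fit(explainer,data(testIdx(j),:));
        % Keep the Shapley values for the last class column only
        shapleyLastClass(testIdx(j),:) = explainer.ShapleyValues{:,end}';
    end
end

% Simple summary: mean absolute Shapley value per predictor
meanAbsShapley = mean(abs(shapleyLastClass),1);
disp(meanAbsShapley)
```

Every observation falls in exactly one test fold, so each row of shapleyLastClass is filled once, by the model that did not see that observation during training.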


Release: R2022b
Asked: 24 Feb 2023
Edited: 25 Feb 2023
