How to plot an averaged ROC curve?

Ishfaque Ahmed on 20 Apr 2022
Commented: Adam Danz on 25 Apr 2022
I am trying to plot the ROC curve of my model over multiple iterations. The curves do not lie at the same locations, so I want to plot one averaged ROC curve from all 10 ROC curves. Please suggest a solution.
1 Comment
Chunru on 21 Apr 2022
You can interpolate each curve onto the same grid and then average them.


Accepted Answer

Chunru on 21 Apr 2022
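% Create sample data: each column of x and y is one (FPR, TPR) curve
numPoints = 50;
nCurves = 10;
x = sort(rand(numPoints, nCurves));              % FPR values, sorted within each column
y = (sort(rand(numPoints, nCurves))).^(1/4);     % increasing, concave TPR values
plot(x, y);
grid on;
hold on;
% Common grid of x values shared by all curves
x0 = linspace(0, 1, 100);
% Interpolate every curve onto the common grid
yinterp = zeros(length(x0), nCurves);
for i = 1:nCurves
    yinterp(:, i) = interp1(x(:, i), y(:, i), x0, 'linear', 'extrap');
end
% Linear extrapolation can overshoot, so clip to the valid TPR range [0, 1]
yinterp = min(max(yinterp, 0), 1);
% Average across curves (mean of each row)
meany = mean(yinterp, 2);
% Plot the averaged curve on top of the individual curves
plot(x0, meany, 'LineWidth', 2);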
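Real ROC output (for example the X and Y outputs of perfcurve) usually contains repeated FPR values where the curve steps vertically, and interp1 raises an error when its sample points are not distinct. The sketch below is one possible way to prepare such curves before averaging, under these assumptions: the curves are stored in hypothetical cell arrays fprList and tprList (one vector per iteration), tied FPR values are collapsed by keeping the largest TPR, and every curve is anchored at (0, 0) and (1, 1) so that no extrapolation is needed.

```matlab
% Hypothetical example data: two ROC curves with repeated FPR values
% (vertical steps), as real ROC output typically has.
fprList = {[0 0 0.25 0.25 0.5 1 1], [0 0.2 0.2 0.6 0.6 1]};
tprList = {[0 0.5 0.5 0.75 1 1 1],  [0 0.4 0.7 0.7 0.9 1]};

x0 = linspace(0, 1, 101);                % common FPR grid
nCurves = numel(fprList);
yinterp = zeros(numel(x0), nCurves);

for i = 1:nCurves
    % Anchor each curve at (0,0) and (1,1) so the grid is fully covered
    fpr = [0; fprList{i}(:); 1];
    tpr = [0; tprList{i}(:); 1];
    % Collapse repeated FPR values, keeping the highest TPR at each one,
    % because interp1 requires distinct sample points.
    [ux, ~, idx] = unique(fpr);
    uy = accumarray(idx, tpr, [], @max);
    yinterp(:, i) = interp1(ux, uy, x0(:), 'linear');
end

meany = mean(yinterp, 2);                % vertical average across curves
```

Keeping the maximum TPR at a tied FPR uses the top of each vertical step; taking the minimum instead would be an equally defensible convention, and the choice only matters at the step locations.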

More Answers (1)

Image Analyst on 21 Apr 2022
Edited: Image Analyst on 21 Apr 2022
Try this:
% Create sample data because the original poster didn't upload theirs.
numPoints = 30;
x1 = sort(rand(1, numPoints));
x2 = sort(rand(1, numPoints));
x3 = sort(rand(1, numPoints));
x4 = sort(rand(1, numPoints));
x5 = sort(rand(1, numPoints));
x6 = sort(rand(1, numPoints));
x7 = sort(rand(1, numPoints));
x8 = sort(rand(1, numPoints));
x9 = sort(rand(1, numPoints));
x10 = sort(rand(1, numPoints));
y1 = sort(rand(1, numPoints));
y2 = sort(rand(1, numPoints));
y3 = sort(rand(1, numPoints));
y4 = sort(rand(1, numPoints));
y5 = sort(rand(1, numPoints));
y6 = sort(rand(1, numPoints));
y7 = sort(rand(1, numPoints));
y8 = sort(rand(1, numPoints));
y9 = sort(rand(1, numPoints));
y10 = sort(rand(1, numPoints));
plot(x1, y1, '-');
hold on;
plot(x2, y2, '-');
plot(x3, y3, '-');
plot(x4, y4, '-');
plot(x5, y5, '-');
plot(x6, y6, '-');
plot(x7, y7, '-');
plot(x8, y8, '-');
plot(x9, y9, '-');
plot(x10, y10, '-');
grid on;
hold on;
%========================================================================
% Since you have your own data you'd start here
% and NOT create the sample data above.
allx = sort([x1,x2,x3,x4,x5,x6,x7,x8,x9,x10], 'ascend');
% Then interpolate all the curves so they share a common x axis.
y1a = interp1(x1, y1, allx);
y2a = interp1(x2, y2, allx);
y3a = interp1(x3, y3, allx);
y4a = interp1(x4, y4, allx);
y5a = interp1(x5, y5, allx);
y6a = interp1(x6, y6, allx);
y7a = interp1(x7, y7, allx);
y8a = interp1(x8, y8, allx);
y9a = interp1(x9, y9, allx);
y10a = interp1(x10, y10, allx);
% Get all y together in one matrix.
allY = [y1a;y2a;y3a;y4a;y5a;y6a;y7a;y8a;y9a;y10a];
% Find out how many curves have valid, non-NaN values at each x location.
counts = sum(~isnan(allY), 1);
% interp1 returns NaN outside the x range where each curve was originally defined,
% so mean(allY, 1) would return NaN wherever any curve is NaN.
% Set the NaNs to zero so that summing each column adds up only the valid values.
allY(isnan(allY)) = 0;
% We still can't use mean(allY, 1), because that would average in the zeros.
% Instead, divide the column sums by counts, the number of curves
% that were not NaN at each x value, to get the true mean.
meany = sum(allY, 1) ./ counts;
% Now plot the mean as a thick black curve.
hold on;
plot(allx, meany, 'k-', 'LineWidth', 4);
title('Thick black line is the mean of all curves')
Note how the mean curve gets a little wiggly near the ends: fewer curves have valid (non-NaN) values there, so the mean moves closer to the curves that remain. For example, say that beyond x = 0.9 only 5 curves have non-NaN values, not the full 10. There you'd want to average only those 5 curves, not all 10. In the picture above, close to x = 1 only the yellow curve has valid values that far out, so the mean equals the yellow curve's y value there. This is why you can't just call mean() on the zero-filled matrix and instead have to divide the sum by the count (because the count changes along x). Does that make sense?
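The ten separately named variables in the answer can be replaced by cell arrays so that the same NaN-aware averaging works for any number of curves. In MATLAB R2015a or later, mean with the 'omitnan' flag skips NaN entries, which should give the same result as dividing the sum of the valid values by their count. A minimal sketch, assuming each curve's sorted x and y values are stored in hypothetical cell arrays xList and yList:

```matlab
% Hypothetical example data: three curves covering different x ranges
xList = {[0.0 0.5 1.0], [0.2 0.6], [0.1 0.4 0.9]};
yList = {[0.0 0.6 1.0], [0.3 0.7], [0.2 0.5 0.95]};
nCurves = numel(xList);

% Common x axis: every x value from every curve, sorted
allx = sort([xList{:}]);

% Interpolate each curve onto the common axis; interp1 returns NaN
% outside the x range over which that curve is defined.
allY = zeros(nCurves, numel(allx));
for k = 1:nCurves
    allY(k, :) = interp1(xList{k}, yList{k}, allx);
end

% Mean over only the valid (non-NaN) curves at each x location
meany = mean(allY, 1, 'omitnan');

% Equivalent sum/count formulation used in the answer
counts = sum(~isnan(allY), 1);
Y0 = allY;
Y0(isnan(Y0)) = 0;
meanyCheck = sum(Y0, 1) ./ counts;
```

Because every value in allx comes from at least one curve, each column has at least one valid entry, so neither formulation divides by zero here.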
3 Comments
Image Analyst on 21 Apr 2022
Well, I guess you could compute the standard deviation at every x location and then get two curves:
  1. the average curve plus the locally varying standard deviation
  2. the average curve minus the locally varying standard deviation.
Then plot those curves: one will lie above the mean curve and one below it. Where only one curve is valid (at the outer ends), the standard deviation will of course be zero.
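The standard-deviation band described above can be computed with the same NaN mask, so that only the curves that are valid at each x location contribute. A minimal sketch, assuming allY is the curves-by-x matrix from the answer with NaN wherever a curve is undefined (the small allY below is hypothetical example data). It uses population normalization (dividing by the count), which makes the deviation exactly zero where only one curve is valid.

```matlab
% Hypothetical example: 3 curves (rows) sampled at 4 common x locations,
% with NaN where a curve is not defined.
allx = [0.1 0.4 0.7 0.95];
allY = [0.2  0.5  0.8   NaN;
        0.3  0.6  0.85  0.95;
        NaN  0.4  0.75  NaN];

valid  = ~isnan(allY);
counts = sum(valid, 1);                 % number of valid curves per x

Y0 = allY;
Y0(~valid) = 0;
meany = sum(Y0, 1) ./ counts;           % NaN-aware mean

dev = bsxfun(@minus, allY, meany);      % deviation of each curve from the mean
dev(~valid) = 0;                        % invalid entries contribute nothing
sdy = sqrt(sum(dev .^ 2, 1) ./ counts); % population standard deviation per x

upperCurve = meany + sdy;
lowerCurve = meany - sdy;
```

After plotting the mean, the band could be drawn with plot(allx, upperCurve, 'k--', allx, lowerCurve, 'k--').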
Adam Danz on 25 Apr 2022
I wonder if curve fitting would be useful. Then you could get error estimates of the fit parameters and plot the smooth fit along with the range of error.
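One way to try this suggestion, offered only as a sketch under stated assumptions, is to fit a parametric ROC model to the averaged points. The binormal model used here, TPR = Phi(a + b*PhiInv(FPR)), is a common choice for ROC curves but is not something proposed in this thread. Phi is written with erfc and PhiInv with erfcinv so that no toolbox is needed, and fminsearch minimizes the squared error. The data are synthetic, and error estimates for a and b (for example by bootstrapping over the iterations) are not shown.

```matlab
% Standard normal CDF and its inverse, written with base-MATLAB functions
Phi    = @(z) 0.5 * erfc(-z / sqrt(2));
PhiInv = @(p) -sqrt(2) * erfcinv(2 * p);

% Binormal ROC model: TPR = Phi(a + b * PhiInv(FPR)), params = [a b]
rocModel = @(params, fpr) Phi(params(1) + params(2) * PhiInv(fpr));

% Synthetic averaged ROC points generated from a = 1.2, b = 0.8
% (stand-ins for averaged x/mean-y values such as x0 and meany).
fpr = linspace(0.01, 0.99, 50);
tpr = rocModel([1.2 0.8], fpr);

% Least-squares fit of a and b with fminsearch
sse  = @(params) sum((rocModel(params, fpr) - tpr) .^ 2);
opts = optimset('TolX', 1e-8, 'TolFun', 1e-12, ...
                'MaxIter', 2000, 'MaxFunEvals', 2000);
fitted = fminsearch(sse, [0 1], opts);

% Smooth fitted curve on a fine grid, e.g. for plotting
fGrid  = linspace(0.001, 0.999, 200);
tprFit = rocModel(fitted, fGrid);
```

The endpoints FPR = 0 and FPR = 1 are left out of the fitting grid because PhiInv is infinite there.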

