How could this code be vectorized?
Fredrik P on 18 Jan 2024
Answered: Voss on 18 Jan 2024
How can I vectorize this code? I suspect that it can be solved using ndgrid and sub2ind, as my last question was, but I just don't get how it should be done. (I'm aware that the vectorized code might very well be slower than the loop-based code. Right now, I'm just hoping to learn how it should be done.)
close; clear; clc;
n1 = 400;
n2 = 9;
n3 = 2;
A1 = rand(n1, n2, n3);
A2 = rand(n1, n2, n3);
A3 = rand(n1, n2, n3);
[~, B] = max(A3, [], 3);
[C11, C21, C31] = version1(A1, A2, A3, B, n1, n2, n3);
% [C12, C22, C32] = version2(A1, A2, A3, B, n1, n2, n3);
% disp([isequal(C11, C12), isequal(C21, C22), isequal(C31, C32)]);
disp(timeit(@() version1(A1, A2, A3, B, n1, n2, n3)));
% disp(timeit(@() version2(A1, A2, A3, B, n1, n2, n3)));
function [C1, C2, C3] = version1(A1, A2, A3, B, n1, n2, n3)
    C1 = nan(n1, n2);
    C2 = C1;
    C3 = C1;
    for j2 = 1:n2
        for j1 = 1:n1
            C1(j1, j2) = A1(j1, j2, B(j1, j2));
            C2(j1, j2) = A2(j1, j2, B(j1, j2));
            C3(j1, j2) = A3(j1, j2, B(j1, j2));
        end
    end
end
Accepted Answer
Cris LaPierre on 18 Jan 2024
Edited: Cris LaPierre on 18 Jan 2024
I'd just have your max function return the linear index instead of the index along the third dimension. Then use that linear index to index A1, A2, and A3 directly.
% your current solution
n1 = 400;
n2 = 9;
n3 = 2;
A1 = rand(n1, n2, n3);
A2 = rand(n1, n2, n3);
A3 = rand(n1, n2, n3);
[~, B] = max(A3, [], 3);
[C11, C21, C31] = version1(A1, A2, A3, B, n1, n2, n3);
% Function definition moved to the bottom of the script
Here is how you would do it using linear indexing:
[~, B1] = max(A3, [], 3, 'linear');
C111 = A1(B1);
C211 = A2(B1);
C311 = A3(B1);
Now compare the results to see if they are the same.
isequal(C11, C111)
isequal(C21, C211)
isequal(C31, C311)
Each comparison returns logical 1 (true), so the results are equal.
% function moved to the bottom of the script
function [C1, C2, C3] = version1(A1, A2, A3, B, n1, n2, n3)
    C1 = nan(n1, n2);
    C2 = C1;
    C3 = C1;
    for j2 = 1:n2
        for j1 = 1:n1
            C1(j1, j2) = A1(j1, j2, B(j1, j2));
            C2(j1, j2) = A2(j1, j2, B(j1, j2));
            C3(j1, j2) = A3(j1, j2, B(j1, j2));
        end
    end
end
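To make the difference between the two index outputs concrete, here is a minimal sketch on a small 2-by-2-by-2 array. The array A and the names M, k, and lin are made up for illustration and are not part of the answer above; the 'linear' option is, to my knowledge, available from R2019a onward.

```matlab
% Hypothetical example data: two 2-by-2 pages stacked along dimension 3
A = cat(3, [1 8; 3 4], [5 2; 7 6]);

[M, k]   = max(A, [], 3);            % k(i,j): page that holds the maximum
[~, lin] = max(A, [], 3, 'linear');  % lin(i,j): linear index of that maximum in A

disp(k)                    % [2 1; 2 2]
disp(lin)                  % [5 3; 6 8]
disp(isequal(A(lin), M))   % indexing with lin recovers the maxima
```

Because lin addresses positions in an n1-by-n2-by-n3 array, it can be reused on any other array of the same size, which is what A1(B1) and A2(B1) rely on.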
More Answers (1)
Voss on 18 Jan 2024
n1 = 400;
n2 = 9;
n3 = 2;
A1 = rand(n1, n2, n3);
A2 = rand(n1, n2, n3);
A3 = rand(n1, n2, n3);
[~, B] = max(A3, [], 3);
[C11_v1, C21_v1, C31_v1] = version1(A1, A2, A3, B, n1, n2, n3);
[C11_v2, C21_v2, C31_v2] = version2(A1, A2, A3, B, n1, n2, n3);
% compare the results
isequal(C11_v1,C11_v2)
isequal(C21_v1,C21_v2)
isequal(C31_v1,C31_v2)
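function [C1, C2, C3] = version2(A1, A2, A3, B, n1, n2, n3)
    % Row and column subscripts for every (j1, j2) position
    [RR, CC] = ndgrid(1:n1, 1:n2);
    % Combine them with the page index B into linear indices
    idx = sub2ind([n1, n2, n3], RR, CC, B);
    % Index each array once; the outputs are n1-by-n2, like idx
    C1 = A1(idx);
    C2 = A2(idx);
    C3 = A3(idx);
end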
function [C1, C2, C3] = version1(A1, A2, A3, B, n1, n2, n3)
    C1 = nan(n1, n2);
    C2 = C1;
    C3 = C1;
    for j2 = 1:n2
        for j1 = 1:n1
            C1(j1, j2) = A1(j1, j2, B(j1, j2));
            C2(j1, j2) = A2(j1, j2, B(j1, j2));
            C3(j1, j2) = A3(j1, j2, B(j1, j2));
        end
    end
end
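For readers wondering what the ndgrid and sub2ind combination computes, the sketch below builds the same linear index in three ways on a small array. The sizes and the names idxA, idxB, and idxC are assumptions chosen for illustration; variants (b) and (c) are alternatives I am adding here, not part of the answer above.

```matlab
% Small hypothetical sizes so the index matrices are easy to inspect
n1 = 4; n2 = 3; n3 = 2;
A = rand(n1, n2, n3);
[~, B] = max(A, [], 3);              % page index of the maximum

% (a) ndgrid + sub2ind, as in version2
[RR, CC] = ndgrid(1:n1, 1:n2);
idxA = sub2ind([n1, n2, n3], RR, CC, B);

% (b) the column-major formula that sub2ind evaluates for a 3-D array
idxB = RR + (CC - 1)*n1 + (B - 1)*n1*n2;

% (c) position within one page plus a page offset, without ndgrid
idxC = reshape(1:n1*n2, n1, n2) + (B - 1)*n1*n2;

disp(isequal(idxA, idxB, idxC))          % all three agree
disp(isequal(A(idxA), max(A, [], 3)))    % and they select the maxima
```

Variant (c) avoids allocating RR and CC, which may matter for large arrays, although whether any of these beats the loop would have to be measured with timeit on the actual data.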