# How to multiply N matrices without a FOR loop? (Slices of 3D array)

Edgars Stegenburgs on 6 Dec 2017
Answered: Steven Lord on 17 Sep 2020
I have a 2x2xN 3D array which, for my purposes, is essentially a stack of N 2x2 matrices. I want to matrix-multiply all of them in sequence, to get the following result:
N = 14;
M = rand(2,2,N);
Z = M(:,:,1)*M(:,:,2)* ... *M(:,:,N);
size(Z) == [2 2]
I can do it with a for loop, but I am looking for a single line approach, something like:
prod(M,3);
but probably with mtimes that would do matrix multiplication along the 3rd dimension (not the element-wise product).
I also converted M into an Nx1 cell array of 2x2 matrices, but that did not get me to the multiplication either.

Andrei Bobrov on 7 Dec 2017
@Stephen: +1
Edgars Stegenburgs on 7 Dec 2017
Thanks for the comment, I will definitely keep this in mind. I may experiment with the sizes of my matrices to see when the for loop becomes faster, but so far it has been much slower.
Of course, in either case you need to take the slices. If there were a MATLAB function that did what I wanted, maybe it would handle the work in some ingenious way, just as MATLAB is very fast at processing arrays.
Because the execution time fluctuates slightly, I always benchmark the code repeated 10x, 100x, ..., depending on how long a single run takes, and check whether there is a noticeable difference; since that is not constant either, I do it several times.
In the end, I try to keep my code in its simplest form: readable and short, which helps with debugging and lets me see more at a glance, which in turn makes the code easier to understand. Elegant...
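The repeat-and-compare benchmarking described above can also be delegated to MATLAB's timeit, which calls the function several times itself and returns a typical time per call. A minimal sketch, assuming the plain loop is wrapped in a local function; the name chainLoop and the array size are illustrative choices, not taken from the thread, and local functions inside a script need R2016b or later:

```matlab
% Benchmark sketch: let timeit handle the repetitions instead of manual tic/toc loops.
M = rand(2, 2, 1000);            % example data; the size is an arbitrary choice
t = timeit(@() chainLoop(M));    % typical time of one call, in seconds
fprintf('Loop: %.3g s per call\n', t);

function out = chainLoop(M)      % hypothetical name, not from the thread
% Multiply the slices M(:,:,1) * M(:,:,2) * ... * M(:,:,end) in order.
out = M(:,:,1);
for k = 2:size(M, 3)
    out = out * M(:,:,k);
end
end
```

Any of the other variants in this thread could be timed the same way by swapping the function handle.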
Jan on 7 Dec 2017
Stephen's comment is very good.
For estimating the effect of optimizing the code, the usual sizes of the inputs matter: is it really a [2 x 2 x N] array, and what values of N do you have? For larger numbers of rows and columns, the main work is done by mtimes, and the loop does not matter much. mtimes calls optimized BLAS or ATLAS functions, so there is no room for further improvement. But I do not know whether these library functions handle tiny 2x2 matrices with unrolled loops, so perhaps a C-Mex function could be more efficient.

Jan on 7 Dec 2017
Edited: Jan on 7 Dec 2017
If you really have 2x2 sub matrices to accumulate, try a C-Mex function:
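#include "mex.h"

// out = CumMProd2x2(M): M(:,:,1) * M(:,:,2) * ... * M(:,:,N) for a real double [2 x 2 x N] array
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  const mwSize *size;
  mwSize N;
  double *p, *q, q11, q12, q21, q22, t11, t21;

  p = mxGetPr(prhs[0]);
  size = mxGetDimensions(prhs[0]);
  if (size[0] != 2 || size[1] != 2) {
    mexErrMsgTxt("1st input must be a [2 x 2 x N] array.");
  }

  // A plain 2x2 matrix has only 2 dimensions, so size[2] must not be read then
  N = (mxGetNumberOfDimensions(prhs[0]) > 2) ? size[2] : 1;

  // Start with the first slice (column-major order: 11, 21, 12, 22)
  q11 = p[0];
  q21 = p[1];
  q12 = p[2];
  q22 = p[3];

  while (--N) {  // Unrolled 2x2 matrix multiplication
    p += 4;      // Advance to the next slice
    t11 = q11 * p[0] + q12 * p[1];
    t21 = q21 * p[0] + q22 * p[1];
    q12 = q11 * p[2] + q12 * p[3];
    q22 = q21 * p[2] + q22 * p[3];
    q11 = t11;
    q21 = t21;
  }

  plhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
  q = mxGetPr(plhs[0]);
  q[0] = q11;
  q[1] = q21;
  q[2] = q12;
  q[3] = q22;
  return;
}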
[EDITED] This is tested now. The speed is very interesting:
function speed
x = rand(2, 2, 1000);
tic; for k = 1:1000, y = CumMProd2x2(x); end; toc
tic; for k = 1:1000, y = CumMProd2x2_AB(x); end; toc
tic
for k = 1:1000 % Jos (10584)
iif = @(varargin) varargin{2*find([varargin{1:2:end}], 1, 'first')}() ;
mprodf = @(F,M,n) iif (n < 2, M(:,:,1), true, @() F(F,M,n-1) * M(:,:,n)) ;
out = mprodf(mprodf, x, size(x, 3));
end
toc
end
function out = CumMProd2x2_AB(M) % Andrei Bobrov
s = size(M, 3);
out = M(:,:,1);
for ii = 2:s
out = out * M(:,:,ii);
end
end
R2016b/64/Win7:
Elapsed time is 0.011403 seconds. C-mex
Elapsed time is 3.884977 seconds. Loop
Elapsed time is 96.038754 seconds. Recursive anonymous function
I was surprised that Andrei's loop is so slow, although it is clearly the nicest and cleanest solution. Let's try to unroll the loops as in the C code:
function out = CumMProd2x2_unroll(M)
q11 = M(1);
q21 = M(2);
q12 = M(3);
q22 = M(4);
c = 1;
for ii = 2:size(M, 3)
c = c + 4;
t11 = q11 * M(c) + q12 * M(c+1);
t21 = q21 * M(c) + q22 * M(c+1);
q12 = q11 * M(c+2) + q12 * M(c+3);
q22 = q21 * M(c+2) + q22 * M(c+3);
q11 = t11;
q21 = t21;
end
out = [q11, q12; q21, q22];
end
This is about 63 times faster (3.88 s vs. 0.061 s) than the direct approach "out * M(:,:,ii)":
Elapsed time is 0.061287 seconds. Unrolled
Obviously Matlab calls very smart, highly optimized libraries for the matrix multiplication, which treat the tiny input with the same heavy machinery as a 1000x1000 matrix.
But this unrolled version is so ugly that I would hesitate to use it in production code. For x = rand(2, 2, 100000) I get these timings for 1000 iterations:
Elapsed time is 1.377695 seconds. C-mex
Elapsed time is 2.872356 seconds. M with unrolled mtimes
Only a factor of 2! Another example showing that loops in Matlab are not so bad compared to C.

Andrei Bobrov on 7 Dec 2017
@Jan: It's great! +1
Jos (10584) on 7 Dec 2017
haha, I really liked my anonymous function approach, and did expect it to perform poorly, but that poor ... haha

Andrei Bobrov on 6 Dec 2017
s = size(M);
out = M(:,:,1);
for ii = 2:s(3)
out = out*M(:,:,ii);
end

Edgars Stegenburgs on 6 Dec 2017
parfor cannot be used for this matrix multiplication, since the factors have to be multiplied in their original order: for matrices, A*B = B*A does not hold in general.
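Order matters because matrix multiplication is not commutative. It is associative, however, so neighbouring factors may be regrouped as long as their left-to-right order is kept; this is the property the pairwise approaches elsewhere in this thread rely on. A small check, with arbitrarily chosen integer matrices so that the comparison is exact:

```matlab
A = [1 2; 3 4];  B = [0 1; 1 0];  C = [2 0; 0 3];   % arbitrary example values
swapped   = isequal(A*B, B*A)            % false: swapping the factors changes the result
regrouped = isequal((A*B)*C, A*(B*C))    % true: regrouping neighbours does not
```

With floating-point data the two groupings can differ by round-off, so an exact isequal is only appropriate for small integer examples like this one.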
Jan on 7 Dec 2017
+1: This is the nicest solution. That multiplying 2x2 matrices is much faster with a hard-coded algorithm is not a fault of this solution.
Although the C-Mex approach is faster, it would be very hard to generalize it to inputs other than 2x2xN arrays.
Matt J on 7 Dec 2017
Jan wrote: "Although the C-Mex approach is faster, it would be very hard to generalize it for inputs beside 2x2xN arrays."
Just wanted to note that, while my solution based on MTIMESX is not as fast as Jan's for the 2x2xN case, it is applicable to arbitrary MxMxN arrays.

Matt J on 6 Dec 2017
The following is not a one-line solution (for that, just stick it in a function file) and requires MTIMESX from the File Exchange. However, I do see a speed-up of a few times over a conventional for-loop:
out=M;
while size(out,3)>1
n=size(out,3);
if mod(n,2)
n=n-1;
A=out(:,:,1:2:n);
B=out(:,:,2:2:n);
out=cat(3,mtimesx(A,B),out(:,:,n+1));
else
A=out(:,:,1:2:n);
B=out(:,:,2:2:n);
out=mtimesx(A,B);
end
end

Matt J on 7 Dec 2017
When I use the mtimesx SPEED flag, things actually get considerably faster:
N=1e7;
M=rand(2,2,N);
M=M./sum(M,2);
%%%Baseline implementation
tic;
out0=eye(2);
for k=1:N
out0=out0*M(:,:,k);
end
toc;
Elapsed time is 12.285774 seconds.
%%%With mtimesx
mtimesx SPEED
tic
out=M;
n=size(out,3);
while n>1
if mod(n,2)
C=out(:,:,n);
n=n-1;
A=out(:,:,1:2:n);
B=out(:,:,2:2:n);
B(:,:,end)=B(:,:,end)*C;
out=mtimesx(A,B);
else
A=out(:,:,1:2:n);
B=out(:,:,2:2:n);
out=mtimesx(A,B);
end
n=size(out,3);
end
toc
Elapsed time is 0.785595 seconds.
James Tursa on 7 Dec 2017
Side Note: MTIMESX by default calls BLAS library routines for matrix multiply so that it matches MATLAB for-loop m-code result, whereas MTIMESX with the 'SPEED' option will use hand-coded inline matrix multiply code for up to 5x5 size slices which may not match MATLAB for-loop m-code result exactly.
Some time back I had a beta version of MTIMESX that implemented the matrix equivalents of 'prod' and 'cumprod'. Maybe it is time I dust that off and finish the implementation/testing so I can publish it.
Matt J on 7 Dec 2017
That is strange, since I still see significant speed-up even with
mtimesx MATLAB
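Since a different multiplication kernel, or a different grouping of the factors, can change the rounding, results from two methods are better compared with a tolerance than with isequal. A sketch using only built-in functions (no MTIMESX): it multiplies the same slices with two different groupings and compares them. The seed, the size and the tolerance 1e-10 are assumed values, not taken from the thread:

```matlab
rng(0);                                  % fixed seed so the run is repeatable
M = rand(2, 2, 64);
M = M ./ sum(M, 2);                      % rows sum to 1, as in the test above, so the product stays bounded

% Reference: sequential left-to-right product
ref = M(:,:,1);
for k = 2:size(M, 3)
    ref = ref * M(:,:,k);
end

% Same product with a different grouping: (first half) * (second half)
h = size(M, 3) / 2;
left = M(:,:,1);
for k = 2:h
    left = left * M(:,:,k);
end
right = M(:,:,h+1);
for k = h+2:size(M, 3)
    right = right * M(:,:,k);
end
alt = left * right;

relErr = norm(alt - ref) / norm(ref);    % relative difference between the two results
ok = relErr < 1e-10;                     % assumed tolerance
```

The same relErr check can be applied to the output of any two implementations being compared.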

Jos (10584) on 6 Dec 2017
Here is one using recursion without a for-loop; not faster though, and somewhat mysterious, but just nice :) ...
M = randi(5,[2 2 4]) ; % data
iif = @(varargin) varargin{2*find([varargin{1:2:end}], 1, 'first')}() ;
mprodf = @(F,M,n) iif (n < 2, M(:,:,1), true, @() F(F,M,n-1) * M(:,:,n)) ;
out = mprodf(mprodf,M,size(M,3)) % voila, it works!

Edgars Stegenburgs on 7 Dec 2017
Yep, this is really mysterious :D
Jos (10584) on 7 Dec 2017
It is the inline version of this recursive m-file:
function X = mprod(M,n)
% X = mprod(M) returns M(:,:,1) * M(:,:,2) * ... * M(:,:,end)
% where M is a 3D array
if nargin==1
X = mprod(M,size(M,3)) ;
elseif n < 2
X = M(:,:,1) ;
else
X = mprod(M,n-1) * M(:,:,n) ;
end
Edgars Stegenburgs on 7 Dec 2017
Looks very nice, although I guess it will be slower due to the recursion.
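The recursive m-file above nests one call per slice. A variant that splits the index range in half at every call keeps the nesting depth at roughly log2(N) while multiplying the slices in the same left-to-right order. This is only a sketch: the name mprodHalf is hypothetical, and whether it is any faster has not been measured here.

```matlab
M = randi(5, [2 2 7]);              % small integer example data
X = mprodHalf(M, 1, size(M, 3));

function X = mprodHalf(M, a, b)     % hypothetical name, not from the thread
% Product M(:,:,a) * ... * M(:,:,b), splitting the range in half at each call.
if a == b
    X = M(:,:,a);                   % a single slice: nothing to multiply
else
    m = floor((a + b) / 2);         % a <= m < b, so both halves are non-empty
    X = mprodHalf(M, a, m) * mprodHalf(M, m+1, b);
end
end
```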

Steven Lord on 17 Sep 2020
If you're using release R2020b or later, take a look at the pagemtimes function introduced in that release.
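pagemtimes multiplies corresponding pages of two arrays, so on its own it does not chain all N slices; it can, however, take the place of mtimesx in a pairwise reduction like the one in Matt J's answer. A minimal sketch, assuming R2020b or later; the helper name chainProdPages is hypothetical:

```matlab
% Pairwise reduction with pagemtimes; a sketch, not a tuned implementation.
M = rand(2, 2, 13);                 % example data with an odd page count
Z = chainProdPages(M);

function out = chainProdPages(M)    % hypothetical helper name
% Returns M(:,:,1) * M(:,:,2) * ... * M(:,:,end) for a 3-D array M.
out = M;
n = size(out, 3);
while n > 1
    if mod(n, 2)                    % odd count: set the last page aside
        last = out(:,:,n);
        n = n - 1;
        out = cat(3, pagemtimes(out(:,:,1:2:n), out(:,:,2:2:n)), last);
    else
        out = pagemtimes(out(:,:,1:2:n), out(:,:,2:2:n));
    end
    n = size(out, 3);               % the page count roughly halves each pass
end
end
```

Each pass multiplies neighbouring pages and keeps their order, so the result should agree with the sequential loop up to round-off.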