Fast implementation of max-plus matrix multiplication

Davide Zorzenon on 13 Apr 2021
Commented: Seth Younger on 14 Sep 2021
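I am trying to implement a fast max-plus algebra multiplication of square matrices in MATLAB. The max-plus multiplication of two matrices $A, B \in \mathbb{R}_{\max}^{n\times n}$, where $\mathbb{R}_{\max} = \mathbb{R}\cup\{-\infty\}$, returns the matrix $C = A \otimes B$ such that $C_{ij} = \max_{k=1,\dots,n} (A_{ik} + B_{kj})$.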
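As a small worked example (the matrices are chosen here purely for illustration), each entry of the max-plus product $A \otimes B$ is the largest of the sums $A_{ik} + B_{kj}$, where an ordinary product would add up the products $A_{ik} B_{kj}$:

$$
\begin{pmatrix}1 & 2\\ 3 & 4\end{pmatrix} \otimes \begin{pmatrix}0 & 1\\ 2 & 0\end{pmatrix}
= \begin{pmatrix}\max(1+0,\,2+2) & \max(1+1,\,2+0)\\ \max(3+0,\,4+2) & \max(3+1,\,4+0)\end{pmatrix}
= \begin{pmatrix}4 & 2\\ 6 & 4\end{pmatrix}
$$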
A naive implementation of the max-plus multiplication in MATLAB is:
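function C = mp_prod(A,B)
% Naive max-plus product: C(i,j) = max over k of A(i,k) + B(k,j)
n = size(A,1);
C = -inf*ones(n);          % -Inf is the neutral element of max
for i = 1:n
    for j = 1:n
        for k = 1:n
            C(i,j) = max(C(i,j), A(i,k) + B(k,j));
        end
    end
end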
However, this is not particularly fast. The fastest implementation I could come up with is:
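function C = fast_mp_prod(A,B)
% Row-wise max-plus product (uses implicit expansion, R2016b or later)
n = size(A,1);
C = -inf*ones(n);
A = transpose(A);              % column i of A is now row i of the original A
for i = 1:n
    C(i,:) = max(A(:,i) + B);  % n-by-n sums, maximum down each column
end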
which takes about 0.18 seconds to multiply two 100×100 matrices 100 times on my computer. Since the same test with "standard" matrix multiplication takes about 0.004 seconds (45 times faster), I was wondering whether there is a way to speed up the code in MATLAB and obtain more comparable timings.
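The speed-up of `fast_mp_prod` comes from computing a whole row of C in one vectorized operation instead of one scalar at a time. The following is a minimal C sketch of the same idea adapted to MATLAB's column-major storage, where whole columns (not rows) are contiguous in memory. The function name `mp_prod_cols` and its plain-array interface are assumptions for illustration, not code from this thread. For each column j it keeps a running column of maxima and updates it with one contiguous column of A per k, so the innermost loop has no dependency between iterations, which makes it a candidate for compiler auto-vectorization.

```c
#include <math.h>
#include <stddef.h>

/* Hypothetical helper: C = A (max-plus) B for column-major arrays.
   A is m x p, B is p x n, C is m x n; element (r, c) of an array
   with `rows` rows is stored at index r + c * rows. */
void mp_prod_cols(const double *A, const double *B, double *C,
                  size_t m, size_t p, size_t n)
{
    for (size_t j = 0; j < n; j++) {
        double *Cj = C + j * m;            /* column j of C */
        for (size_t i = 0; i < m; i++) {
            Cj[i] = -INFINITY;             /* max-plus "zero" */
        }
        for (size_t k = 0; k < p; k++) {
            const double *Ak = A + k * m;  /* column k of A, contiguous */
            double bkj = B[k + j * p];     /* scalar B(k, j) */
            for (size_t i = 0; i < m; i++) {
                double s = Ak[i] + bkj;
                Cj[i] = (s > Cj[i]) ? s : Cj[i];  /* independent per i */
            }
        }
    }
}
```

With the loop order j, k, i, every pass of the inner loop reads and writes stride-1 memory, which is the C counterpart of handing `max` a whole block of sums at once.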
If this is not possible, would it be worth trying to create a package similar to LAPACK for max-plus algebra to get better performance? Or is this the maximum achievable speed for reasons like "the CPU is inherently slower at computing the maximum than the product of two numbers"?
Jan on 13 Apr 2021
The naive max() implementation contains a branch. Branches slow down processing in general because they impede the CPU's pipelining: the CPU predicts which branch will be taken using heuristics, and in the case of max() about half of these predictions are wrong.
Standard matrix multiplication is performed by a highly optimized library. Even the naive loop implementation in Matlab profits from Matlab's JIT acceleration, but the JIT does not handle max() with the same efficiency.
Do you have a C compiler installed?
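Jan's point about the branch can be illustrated with a max-plus "dot product" of one row of A and one column of B. The sketch below (the name `mp_dot` and its stride parameter are hypothetical) writes the update as a conditional expression instead of an `if` statement. Optimizing compilers often turn this pattern, and sometimes the `if` form as well, into a branch-free select or max instruction, but that is compiler-dependent and not guaranteed. The sketch assumes NaN-free inputs.

```c
#include <math.h>
#include <stddef.h>

/* Hypothetical helper: returns the maximum over k of a[k * a_stride] + b[k].
   `a` walks along a row of a column-major matrix (stride = number of rows),
   `b` is a contiguous column. Returns -INFINITY when len == 0. */
double mp_dot(const double *a, size_t a_stride, const double *b, size_t len)
{
    double t = -INFINITY;
    for (size_t k = 0; k < len; k++) {
        double s = a[k * a_stride] + b[k];
        t = (s > t) ? s : t;   /* select expression instead of an if-statement */
    }
    return t;
}
```

Whether this actually removes the branch can be checked by inspecting the generated assembly for the target compiler and flags.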
Davide Zorzenon on 13 Apr 2021
Thank you for the explanation! Yes, I have it.


Accepted Answer

Jan on 13 Apr 2021
Edited: Jan on 15 Apr 2021
A C-mex version:
[EDITED: A check for real double inputs was added]
// mp_prod_mex.c
// C = mp_prod_mex(A, B)
// INPUT: A, B: Real double [m x p] and [p x n] matrices.
// OUTPUT: Real double [m x n] matrix.
//
// Equivalent Matlab code:
//   [m, p] = size(A); n = size(B, 2);
//   C = -inf(m, n);
//   for i = 1:m
//     for j = 1:n
//       for k = 1:p
//         C(i,j) = max(C(i,j), A(i,k) + B(k,j));
//       end
//     end
//   end
//
// COMPILE:
//   mex -O -R2018a mp_prod_mex.c
//
// Jan, Heidelberg, (C) 2021, License: CC BY-SA 3.0
// Handling of rectangular matrices: Bruno Luong
#include "mex.h"

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  double *A, *Ak, *B, *Bj, *Bk, *C, mInf = -mxGetInf();
  register double s, t;
  mwSize m, n, p, i, j, k;

  // Check inputs before accessing their data:
  if (nrhs != 2) {
    mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadInput",
                      "mp_prod_mex: 2 inputs required.");
  }
  if (!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]) ||
      mxIsComplex(prhs[0]) || mxIsComplex(prhs[1])) {
    mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadInput",
                      "mp_prod_mex: Inputs must be real double matrices.");
  }
  m = mxGetM(prhs[0]);  // size(A, 1)
  n = mxGetN(prhs[1]);  // size(B, 2)
  p = mxGetN(prhs[0]);  // size(A, 2) and size(B, 1)
  if (mxGetM(prhs[1]) != p) {
    mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadInput",
                      "mp_prod_mex: Sizes of A and B do not match.");
  }
  // Get inputs:
  A = mxGetDoubles(prhs[0]);
  B = mxGetDoubles(prhs[1]);

  // Create output:
  plhs[0] = mxCreateDoubleMatrix(m, n, mxREAL);
  C = mxGetDoubles(plhs[0]);

  // Calculate result:
  for (j = 0; j < n; j++) {   // Loop over columns of B first for linear access to C
    Bj = B + j * p;           // Column j of B starts at j * p (B has p rows)
    for (i = 0; i < m; i++) {
      Ak = A + i;             // Row i of A: stride m in column-major storage
      Bk = Bj;
      t = mInf;               // Empty max is -Inf, so p == 0 gives -inf(m, n)
      for (k = 0; k < p; k++) {
        s = *Ak + *Bk++;
        Ak += m;
        if (s > t) {
          t = s;
        }
      }
      *C++ = t;
    }
  }
  return;
}
Timings, 100x100 input, 100 iterations:
% Matlab R2018b, i5 mobile:
Elapsed time is 0.365667 seconds. % 1st version from question
Elapsed time is 0.253912 seconds. % 2nd version from question
Elapsed time is 0.337319 seconds. % Bruno's version
Elapsed time is 0.077329 seconds. % C mex
Transposing A to obtain contiguous memory access gave no significant advantage.
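For reference, the transposed-A variant mentioned above could look like the following sketch. The function name `mp_prod_transposed` is hypothetical, and it uses `malloc` for the temporary copy, where a MEX file would more typically use `mxMalloc`. After transposing, row i of A becomes a contiguous column of `At`, so the innermost loop reads both operands with stride 1. It is not a measured improvement; per Jan's timing above, this kind of change gave no significant advantage for 100×100 inputs.

```c
#include <math.h>
#include <stddef.h>
#include <stdlib.h>

/* Hypothetical helper: max-plus product using a transposed copy of A.
   Column-major storage: A is m x p, B is p x n, C is m x n.
   Returns 0 on success, -1 if the temporary buffer cannot be allocated. */
int mp_prod_transposed(const double *A, const double *B, double *C,
                       size_t m, size_t p, size_t n)
{
    double *At = malloc((m * p > 0 ? m * p : 1) * sizeof *At);  /* p x m */
    if (At == NULL) {
        return -1;
    }
    for (size_t k = 0; k < p; k++) {
        for (size_t i = 0; i < m; i++) {
            At[k + i * p] = A[i + k * m];   /* At(k, i) = A(i, k) */
        }
    }
    for (size_t j = 0; j < n; j++) {
        const double *Bj = B + j * p;       /* column j of B, contiguous */
        for (size_t i = 0; i < m; i++) {
            const double *Ai = At + i * p;  /* row i of A, now contiguous */
            double t = -INFINITY;
            for (size_t k = 0; k < p; k++) {
                double s = Ai[k] + Bj[k];
                if (s > t) {
                    t = s;
                }
            }
            C[i + j * m] = t;
        }
    }
    free(At);
    return 0;
}
```

The extra copy costs O(m·p) work and memory, which is small next to the O(m·n·p) product itself.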
Davide Zorzenon on 14 Apr 2021
Edited: Davide Zorzenon on 14 Apr 2021
@Bruno Luong: thanks for pointing out the case of matrices with dimensions m×0 and 0×n. I had never thought about that, but to be consistent, the result of the max-plus multiplication should be -inf(m,n), as you wrote.
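The reasoning behind `-inf(m,n)`: $-\infty$ is the neutral element of max, just as 0 is the neutral element of ordinary addition, so a maximum over an empty index set is $-\infty$. Writing $A \otimes B$ for the max-plus product of an $m\times 0$ matrix $A$ and a $0\times n$ matrix $B$:

$$
(A\otimes B)_{ij} = \max_{k \in \emptyset}\,(A_{ik} + B_{kj}) = -\infty
\qquad \text{for all } i = 1,\dots,m,\; j = 1,\dots,n.
$$

This mirrors ordinary multiplication, where `zeros(m,0)*zeros(0,n)` in MATLAB returns `zeros(m,n)`.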
Seth Younger on 14 Sep 2021
@Jan I would love to connect with you and learn more about how this implementation works. Are you slicing the matrix up? I am confused by the (A,1) and (A,2).
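Regarding the question above: the MEX file does not slice the matrix. MATLAB stores matrices column by column (column-major), so element `A(i,k)` of an m×p matrix sits at 0-based offset `i + k*m` in the data returned by `mxGetDoubles`. The comments `// size(A, 1)` and `// size(A, 2)` only say that `mxGetM` and `mxGetN` return the number of rows and columns. The sketch below rewrites the MEX loops with explicit indices instead of pointer increments; `mp_prod_indexed` is a hypothetical name, and it is meant to compute the same result as the pointer version.

```c
#include <math.h>
#include <stddef.h>

/* Hypothetical index-based version of the MEX loops (column-major data).
   A is m x p, B is p x n, C is m x n. */
void mp_prod_indexed(const double *A, const double *B, double *C,
                     size_t m, size_t p, size_t n)
{
    for (size_t j = 0; j < n; j++) {        /* column of B and C */
        for (size_t i = 0; i < m; i++) {    /* row of A and C */
            double t = -INFINITY;           /* stays -Inf if p == 0 */
            for (size_t k = 0; k < p; k++) {
                /* A(i,k) is A[i + k*m]: "Ak += m" steps along row i.
                   B(k,j) is B[k + j*p]: "Bk++" steps down column j. */
                double s = A[i + k * m] + B[k + j * p];
                if (s > t) {
                    t = s;
                }
            }
            C[i + j * m] = t;               /* "*C++" fills C column by column */
        }
    }
}
```

Because j is the outer loop and i the inner one, the index `i + j*m` increases by exactly 1 per result, which is why the pointer version can simply write `*C++ = t`.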


More Answers (1)

Bruno Luong on 13 Apr 2021
Edited: Bruno Luong on 13 Apr 2021
function C = mp_prod(A,B)
m=size(A,1);
n=size(B,2);
AA=reshape(A,m,1,[]);
BB=reshape(B.',1,n,[]);
C=max(AA+BB,[],3);
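This version builds an m×n×p array whose element (i,j,k) is `A(i,k) + B(k,j)` (implicit expansion of `AA`, of size m×1×p, against `BB`, of size 1×n×p) and then takes the maximum along the third dimension. The C sketch below (hypothetical name `mp_prod_tensor`, assuming p ≥ 1) spells that computation out explicitly, including the temporary array of m·n·p doubles, which for 100×100 inputs is 10^6 doubles, i.e. 8 MB. It only illustrates what the vectorized expression computes, not how MATLAB evaluates it internally.

```c
#include <math.h>
#include <stddef.h>
#include <stdlib.h>

/* Hypothetical illustration of C = max(AA + BB, [], 3), column-major storage.
   A is m x p, B is p x n, C is m x n. Returns -1 if allocation fails. */
int mp_prod_tensor(const double *A, const double *B, double *C,
                   size_t m, size_t p, size_t n)
{
    size_t total = m * n * p;
    double *T = malloc((total > 0 ? total : 1) * sizeof *T);  /* m x n x p */
    if (T == NULL) {
        return -1;
    }
    /* Step 1: T(i,j,k) = A(i,k) + B(k,j), the result of AA + BB. */
    for (size_t k = 0; k < p; k++) {
        for (size_t j = 0; j < n; j++) {
            for (size_t i = 0; i < m; i++) {
                T[i + j * m + k * m * n] = A[i + k * m] + B[k + j * p];
            }
        }
    }
    /* Step 2: C(i,j) = maximum over k of T(i,j,k), i.e. max(..., [], 3). */
    for (size_t j = 0; j < n; j++) {
        for (size_t i = 0; i < m; i++) {
            double t = -INFINITY;
            for (size_t k = 0; k < p; k++) {
                double s = T[i + j * m + k * m * n];
                if (s > t) {
                    t = s;
                }
            }
            C[i + j * m] = t;
        }
    }
    free(T);
    return 0;
}
```

The trade-off is memory for simplicity: the loop-based versions need no intermediate array, while this formulation materializes all m·n·p sums before reducing them.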
tic/toc result
>> A=rand(100);
>> B=rand(100);
>> tic; C = mp_prod(A,B); toc
Elapsed time is 0.005206 seconds.
Jan on 14 Apr 2021
Edited: Jan on 14 Apr 2021
@Davide Zorzenon: Which machine and Matlab version are you using?
The different timings could mean that Davide's setup is more efficient for the loop, or that Bruno's setup is more efficient for the vectorized solution.
Davide Zorzenon on 14 Apr 2021
@Jan: I use Matlab R2019a; my computer runs Windows 10 x64 with an Intel Core i7 (2.20 GHz, 6 cores) and 16 GB RAM.

