Matrix Optimal Multiplication at Compile Time with C++ Templates

What’s in this Article?

In a previous article, I discussed an algorithm to solve the matrix optimal multiplication problem.

In this article, I will implement the algorithm at compile time with C++ template meta-programming.

A refresher of the matrix optimal multiplication problem

Consider a matrix multiplication problem.

\(A B C \)

Where \(A, B, C\) are matrices: \(A \in \mathbb{R}^{40 \times 30}\), \(B \in \mathbb{R}^{30 \times 20}\), \(C \in \mathbb{R}^{20 \times 10}\).

We can evaluate the multiplication by ((A * B) * C) which takes 32000 floating-point multiplications. We can also evaluate it by (A * (B * C)) which takes 18000 floating-point multiplications.
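As a quick check of these two numbers, here is a minimal sketch. It assumes the usual rule that multiplying an m x k matrix by a k x n matrix takes m * k * n scalar multiplications. The helper name mul_cost is made up for this example and is not part of the article’s code.

```cpp
// Scalar multiplications needed to multiply an (m x k) matrix by a (k x n) matrix.
constexpr int mul_cost(int m, int k, int n) { return m * k * n; }

// A: 40x30, B: 30x20, C: 20x10

// (A * B) * C: A * B gives a 40x20 matrix, which is then multiplied by C.
constexpr int cost_ab_c = mul_cost(40, 30, 20) + mul_cost(40, 20, 10);

// A * (B * C): B * C gives a 30x10 matrix, and A is then multiplied by it.
constexpr int cost_a_bc = mul_cost(30, 20, 10) + mul_cost(40, 30, 10);

static_assert(cost_ab_c == 32000, "(A * B) * C");
static_assert(cost_a_bc == 18000, "A * (B * C)");
```

Both checks run inside the compiler, which is the same idea the rest of the article builds on.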

The goal of the matrix optimal multiplication problem is to find the matrix multiplication order such that the number of floating-point multiplications is minimized.

In the previous article, I solved the problem in Python in several steps,

  1. Use an expression tree to represent a matrix multiplication.
  2. Given an expression tree, we can count the number of multiplications needed to evaluate it.
  3. We can find the optimal expression by iterating over all expression trees and computing the number of multiplications for each.
  4. We can make the search faster by using memoization (dynamic programming).

The Python implementation solves the problem at run time. However, if we know the matrix sizes at compile time, we should be able to solve the problem at compile time. In other words, at compile time we find the optimal expression and compile it into machine code. At run time, we simply execute the machine code.

In this article, I will solve the Matrix Optimal Multiplication Problem in compile time by using C++ template and constexpr.

You need to know basic C++ templates to understand this article. I won’t go into the algorithmic details of the problem. Please read this article about the matrix optimal multiplication algorithm.

The outline of the article

  1. Explain some useful C++ template & constexpr techniques.
  2. Implement the expression tree in compile-time.
  3. Find the minimum of all expression trees in compile-time.
  4. A way to evaluate a compile-time expression in run-time.

Code link.

1. C++ Compile-Time Tricks

1.1 Template

From wikipedia: Templates are a feature of the C++ programming language that allows functions and classes to operate with generic types. This allows a function or class to work on many different data types without being rewritten for each one.

To be honest, a C++ template is, to a first approximation, just text substitution.

For example,

template<typename T>
T add(T a, T b) {
    return a + b;
}

int main() {
    int a = 1, b = 2;
    add(a, b);
    
    float c = 1.0f, d = 2.0f;
    add(c, d);
}

The compiler will infer the type of T from the context and substitute T with the inferred type. The above code is equivalent to,

int add_int(int a, int b) {
    return a + b;
}

float add_float(float a, float b) {
    return a + b;
}

int main() {
    int a = 1, b = 2;
    add_int(a, b);
    
    float c = 1.0f, d = 2.0f;
    add_float(c, d);
}

The compiler basically does a text substitution of T by the inferred type. It is done at compile time. At run time, we simply execute the substituted function.

A template is just text substitution!
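To see the substitution happen, here is a small sketch (not from the original code) that asks the compiler which type it picked for T. It only uses decltype and std::is_same from the standard library.

```cpp
#include <type_traits>

template <typename T>
T add(T a, T b) {
    return a + b;
}

// decltype reports the return type of each call without running it.
static_assert(std::is_same<decltype(add(1, 2)), int>::value,
              "T is deduced as int");
static_assert(std::is_same<decltype(add(1.5f, 2.5f)), float>::value,
              "T is deduced as float");
```

If both arguments had different types, for example add(1, 2.5f), the deduction would fail, because the compiler cannot pick a single T.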

Consider another example where we assume a type has a member variable cost.

struct S1 {
  int cost = 1;  
};

struct S2 {
  int cost = 1;  
};

struct S3 {
  float cost = 1;  
};

template<typename A, typename B>
int add_cost(A a, B b) {
    // Note that we don't know what type A is. 
    // All we know is A has a member variable named cost.
    return a.cost + b.cost;
}


int main() {
    S1 s1;
    S2 s2;
    S3 s3;
    
    add_cost(s1, s2);
    add_cost(s2, s3);
    add_cost(s1, s3);
}

The compiler will infer the type and do a text substitution. The above code is equivalent to,

struct S1 {
  int cost = 1;  
};

struct S2 {
  int cost = 1;  
};

struct S3 {
  float cost = 1;  
};

int add_cost(S1 a, S2 b) {
    return a.cost + b.cost;
}

int add_cost(S2 a, S3 b) {
    return a.cost + b.cost;
}

int add_cost(S1 a, S3 b) {
    return a.cost + b.cost;
}

int main() {
    S1 s1;
    S2 s2;
    S3 s3;
    
    // calls add_cost(S1, S2)
    add_cost(s1, s2);
    // calls add_cost(S2, S3)
    add_cost(s2, s3);
    // calls add_cost(S1, S3)
    add_cost(s1, s3);
}

No matter how complex the template is, the compiler will substitute the type and try to compile the substituted code.

We can also define rules for the text substitution. For example, we can do a specific substitution for one type while doing the default substitution for all other types. In this article, we are not going to use these techniques.
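For completeness, here is a minimal sketch of such a rule, usually called template specialization. The name IsInt is hypothetical and only serves the illustration.

```cpp
// Default rule: used for every type that has no more specific rule.
template <typename T>
struct IsInt {
    static constexpr bool value = false;
};

// Specific rule: the compiler picks this version when T is int.
template <>
struct IsInt<int> {
    static constexpr bool value = true;
};

static_assert(IsInt<int>::value, "the int rule is selected");
static_assert(!IsInt<float>::value, "the default rule is selected");
```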

1.2 Compile-time constant: constexpr

In C++, we can define compile-time constants by the keyword constexpr. Moreover, we can write constexpr functions to do computations in compile-time.

For example, the greatest common divisor can be implemented as,

constexpr int gcd(int a, int b){
  while (b != 0){
    auto t= b;
    b= a % b;
    a= t;
  }
  return a;
}

int main(){
  constexpr int i= gcd(11,121);
}

Because the inputs to gcd are compile-time constants and i is declared constexpr, the result is computed at compile time. The keyword constexpr does the trick. At run time, the code is equivalent to,

int main(){
  // The gcd() is computed in compile time.
  constexpr int i = 11;
}
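One way to convince yourself that the call really is evaluated by the compiler is static_assert, which only accepts compile-time constants. The sketch below repeats gcd from above; it assumes C++14 or later because of the loop inside a constexpr function.

```cpp
constexpr int gcd(int a, int b) {
    while (b != 0) {
        int t = b;
        b = a % b;
        a = t;
    }
    return a;
}

// This line compiles only if the compiler itself can evaluate gcd(11, 121).
static_assert(gcd(11, 121) == 11, "gcd is evaluated at compile time");
```

If gcd were not constexpr, the static_assert would be rejected with a compile error instead of failing at run time.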

In the next section, we will use the C++ template and constexpr to implement the expression tree for a matrix multiplication expression.

2. The Expression Tree

The expression tree is the standard way to represent an expression in a programming language.

For example, the expression trees for these matrix multiplication expressions are,

@ means multiplication.

expression: (((A@(B@(C@D)))@E)@F)
expression tree:
        _@ 
       /  \
  _____@  F
 /      \  
 @_     E  
/  \       
A  @_      
  /  \     
  B  @     
    / \    
    C D    


expression: ((((A@B)@(C@D))@E)@F)
expression tree:
        _@ 
       /  \
    ___@  F
   /    \  
  _@_   E  
 /   \     
 @   @     
/ \ / \    
A B C D    


This is the default expression: A @ B @ C @ D @ E @ F
expression: (((((A@B)@C)@D)@E)@F)
expression tree:
        _@ 
       /  \
      _@  F
     /  \  
    _@  E  
   /  \    
  _@  D    
 /  \      
 @  C      
/ \        
A B        

In C++, we can use a type to represent matrix expression and a type to represent the product(multiplication) expression.

The matrix expression M represents a matrix variable.

The product expression Prod<L, R> represents a matrix multiplication. It has a left operand and a right operand.

We can combine them together.

The expression A * (B * C) can be represented as Prod<MA, Prod<MB, MC>>. Note, Prod<MA, Prod<MB, MC>> is a C++ type which is generated at compile time. The type implicitly defines an expression tree.

In compile-time, with some tricks, we can do tree operations on the expression type. e.g. traverse the tree to compute the total number of multiplications for the expression. This is the motivation to define expressions as types.

In the next section, I will define the matrix expression M, and the product expression Prod<A, B>.

2.1 The Matrix Expression

The matrix expression represents a matrix variable in an expression. e.g. In the expression (A * B) * C, A, B and C are matrix expressions.

The matrix expression can be implemented as,

// Use a template alias for Eigen::Matrix
template <int32_t R, int32_t C>
using Matrix = Eigen::Matrix<double, R, C>;

// The MatExpression is a template struct of ROWS and COLS.
template <int32_t ROWS, int32_t COLS>
struct MatExpression {
    // Used in Min.
    using Type = MatExpression;
    
    // Alias for Eigen::Matrix again.
    using MatrixType = Matrix<ROWS, COLS>;

    // Important!
    // We can deduce the ROWS and COLS from the input Eigen::Matrix
    MatExpression(MatrixType&)
    {
    }

    ...

    static constexpr int32_t cost()
    {
        return 0;
    }

    static constexpr int32_t rows()
    {
        return ROWS;
    }

    static constexpr int32_t cols()
    {
        return COLS;
    }

    ...
};

The MatExpression is a template of ROWS and COLS. In its constructor, we deduce its ROWS and COLS from an Eigen::Matrix.

The interesting parts are the 3 constexpr functions. cost() function tracks the number of multiplications. For MatExpression it’s always 0. The rows() and cols() track the size of the matrix.

Note that cost(), rows() and cols() are constexpr functions. It means we can do compile-time computation using these functions.

Also note that the struct doesn’t contain any member variables. The usage of the type is to (1) track number of multiplications and (2) track size of matrix. The usage will be clear in section 5.

2.2 The Product Expression

The product expression can be implemented as,

template <typename A, typename B>
struct ProdExpression {
    // Used in struct Max/ struct Min
    using Type = ProdExpression;
    using MatrixType = Matrix<A::rows(), B::cols()>;

    ...
    
    // deduce type A & B from inputs
    ProdExpression(A, B)
    {
    }

    static constexpr int32_t cost()
    {
        return A::cost() + B::cost() + A::rows() * A::cols() * B::cols();
    }

    static constexpr int32_t rows()
    {
        return A::rows();
    }

    static constexpr int32_t cols()
    {
        return B::cols();
    }

    ...
};

The product expression takes 2 expressions as inputs and propagates the size of matrices. Note the cost(), rows() and cols() are constexpr functions. It means these functions are evaluated at compile-time.

For example, we can substitute type A and type B by 2 matrices.

using M1 = MatExpression<10, 10>;
using M2 = MatExpression<10, 5>;

using Product_of_2_mat = ProdExpression<M1, M2>;

// Product_of_2_mat is equal to this type
//   type(A) = M1
//   type(B) = M2
// We will call A::cost(), A::rows(), A::cols(), B::cost(), B::rows() and B::cols().
struct ProdExpression_Substituted {
    // Used in struct Max/ struct Min
    using Type = ProdExpression<M1, M2>;
    using MatrixType = Matrix<M1::rows(), M2::cols()>;

    ...

    static constexpr int32_t cost()
    {
        return M1::cost() + M2::cost() + M1::rows() * M1::cols() * M2::cols();
    }

    static constexpr int32_t rows()
    {
        return M1::rows();
    }

    static constexpr int32_t cols()
    {
        return M2::cols();
    }

    ...
};

Then the product expression represents A * B. The cost() function computes the number of muls needed to evaluate A * B. The rows() and cols() functions compute the size of the result of A * B.

The product expression also takes product expressions as input. For example, M0 * (M1 * M2) can be represented as ProdExpression<M0, ProdExpression<M1, M2>>. The outer expression can be represented as,

using M0 = MatExpression<8, 10>;
using M1 = MatExpression<10, 10>;
using M2 = MatExpression<10, 5>;

using Product_of_3_mat = ProdExpression<M0, ProdExpression<M1, M2>>;

// Product_of_3_mat is equal to this type
//   type(A) = M0
//   type(B) = ProdExpression<M1, M2>
// We will call A::cost(), A::rows(), A::cols(), B::cost(), B::rows() and B::cols().
struct ProdExpression_Substituted {
    // Used in struct Max/ struct Min
    using Type = ProdExpression<M0, ProdExpression<M1, M2>>;
    using MatrixType = Matrix<M0::rows(), ProdExpression<M1, M2>::cols()>;

    ...

    static constexpr int32_t cost()
    {
        return M0::cost() + ProdExpression<M1, M2>::cost() 
            + M0::rows() * M0::cols() * ProdExpression<M1, M2>::cols();
    }

    static constexpr int32_t rows()
    {
        return M0::rows();
    }

    static constexpr int32_t cols()
    {
        return ProdExpression<M1, M2>::cols();
    }

    ...
};

As a result, we can represent a matrix multiplication expression by using a combination of ProdExpression and the MatExpression.

If we want to know the number of multiplications to evaluate an expression, we can simply call the cost() method in the top-level product expression type. For example, we can call ProdExpression<M0, ProdExpression<M1, M2>>::cost() to know the number of multiplications to evaluate M0 * (M1 * M2).
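The cost bookkeeping can be tried without Eigen. The sketch below is a stripped-down version of the two types, keeping only cost(), rows() and cols(). The static_assert on the inner dimensions is my own addition and is not in the article’s code.

```cpp
#include <cstdint>

// Leaf: a matrix of fixed size. Simplified: no Eigen types, no constructor.
template <int32_t ROWS, int32_t COLS>
struct MatExpression {
    static constexpr int32_t cost() { return 0; }
    static constexpr int32_t rows() { return ROWS; }
    static constexpr int32_t cols() { return COLS; }
};

// Inner node: the product of two sub-expressions.
template <typename A, typename B>
struct ProdExpression {
    // Added for this sketch: reject products whose shapes do not fit.
    static_assert(A::cols() == B::rows(), "inner dimensions must match");

    static constexpr int32_t cost() {
        return A::cost() + B::cost() + A::rows() * A::cols() * B::cols();
    }
    static constexpr int32_t rows() { return A::rows(); }
    static constexpr int32_t cols() { return B::cols(); }
};

using M0 = MatExpression<8, 10>;
using M1 = MatExpression<10, 10>;
using M2 = MatExpression<10, 5>;

// M0 * (M1 * M2): 10*10*5 for the inner product, 8*10*5 for the outer one.
static_assert(ProdExpression<M0, ProdExpression<M1, M2>>::cost() == 900, "");
// (M0 * M1) * M2: 8*10*10 for the inner product, 8*10*5 for the outer one.
static_assert(ProdExpression<ProdExpression<M0, M1>, M2>::cost() == 1200, "");
```

For these sizes, M0 * (M1 * M2) is the cheaper order, and the compiler knows that before the program ever runs.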

3. Compile-Time Min Function

In the previous section, we defined a way to compute the number of multiplications for a matrix product expression.

Since our goal is to find the optimal expression, the problem becomes,

Given some expressions, how to find the expression with the minimum cost in compile-time?

For example, for 3 matrices optimal multiplication problem, there are 2 possible expressions,

Expression: A * (B * C)
C++ type: ProdExpression<A, ProdExpression<B, C>>
Expression Tree:
 @_  
/  \ 
A  @ 
  / \
  B C

Expression: (A * B) * C
C++ type: ProdExpression<ProdExpression<A, B>, C>
Expression Tree:
  _@ 
 /  \
 @  C
/ \  
A B

We want to find the expression with the minimum cost. i.e.

Min(ProdExpression<A, ProdExpression<B, C>>::cost(), ProdExpression<ProdExpression<A, B>, C>::cost())

(maybe I should use ArgMin.)

The problem is how to do the minimum operation in compile-time.

My solution is to define a type aliasing conditionally.

namespace matrix_optimal_product {
template <typename A, typename B>
struct Min {
    static constexpr int32_t cost()
    {
        // maybe I should use static_cast
        return ((int32_t)A::cost() < (int32_t)B::cost()) ? (int32_t)A::cost() : (int32_t)B::cost();
    }

    // The conditional<x, A, B>::type is A or B.
    // The reason to use XX::type::Type is to propagate the Min operation.
    // Without it, Min<Min<A,B>, C>::Type can be Min<A,B> or C.
    // I defined A::Type = A; B::Type = B; With XX::type::Type,
    // Min<Min<A,B>, C>::Type is A or B or C.
    using Type = typename std::conditional<(A::cost() < B::cost()), A, B>::type::Type;

    ...
};

The std::conditional<condition, type1, type2> does the trick.

When condition is true the std::conditional<condition, type1, type2>::type is type1.

When condition is false the std::conditional<condition, type1, type2>::type is type2.
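A minimal sketch of std::conditional on its own, before it is combined with the expression types. The two structs Cheap and Pricey are hypothetical and exist only for this example.

```cpp
#include <type_traits>

struct Cheap  { static constexpr int cost() { return 1; } };
struct Pricey { static constexpr int cost() { return 5; } };

// The condition is a compile-time constant, so the alias is resolved by the compiler.
using Winner = std::conditional<(Cheap::cost() < Pricey::cost()), Cheap, Pricey>::type;

static_assert(std::is_same<Winner, Cheap>::value, "the cheaper type is selected");
```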

What the code does is,

  1. return the minimum cost for type A and type B. It is the min operation.
  2. define the Type to be the type with minimum cost. It is the argmin operation.

If we want to know which expression corresponds to the minimum cost, we look at Min<E1, E2>::Type, which is an alias for the optimal expression.

Why do we need to return the cost()? Because then a Min type can take other Min types as its template arguments, so we can compute the min of more than 2 expressions. For example, the min of all possible expressions for multiplying M1 * M2 * M3 * M4 can be written as,

// The min of all possible expressions for 4 matrices
using MinExpression4 = 
    Min<Prod<M1, Prod<M2, Prod<M3, M4>>>,
        Min<Prod<M1, Prod<Prod<M2, M3>, M4>>,
            Min<Prod<Prod<M1, M2>, Prod<M3, M4>>,
                Min<Prod<Prod<M1, Prod<M2, M3>>, M4>,
                    Prod<Prod<Prod<M1, M2>, M3>, M4>>>>>;

// The expression with the minimum cost is                    
using OptimalExpression = MinExpression4::Type;
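To try the min and argmin behaviour in isolation, the following sketch applies Min to the 3-matrix example from the introduction (40x30, 30x20, 20x10). The expression types are simplified copies without Eigen, and the names MA, MB, MC and Best are chosen for this example.

```cpp
#include <cstdint>
#include <type_traits>

template <int32_t ROWS, int32_t COLS>
struct MatExpression {
    using Type = MatExpression;
    static constexpr int32_t cost() { return 0; }
    static constexpr int32_t rows() { return ROWS; }
    static constexpr int32_t cols() { return COLS; }
};

template <typename A, typename B>
struct ProdExpression {
    using Type = ProdExpression;
    static constexpr int32_t cost() {
        return A::cost() + B::cost() + A::rows() * A::cols() * B::cols();
    }
    static constexpr int32_t rows() { return A::rows(); }
    static constexpr int32_t cols() { return B::cols(); }
};

template <typename A, typename B>
struct Min {
    static constexpr int32_t cost() {
        return (A::cost() < B::cost()) ? A::cost() : B::cost();
    }
    // ::type picks A or B; ::Type unwraps a nested Min down to a real expression.
    using Type = typename std::conditional<(A::cost() < B::cost()), A, B>::type::Type;
};

template <typename A, typename B>
using Prod = ProdExpression<A, B>;

using MA = MatExpression<40, 30>;
using MB = MatExpression<30, 20>;
using MC = MatExpression<20, 10>;

using Best = Min<Prod<MA, Prod<MB, MC>>, Prod<Prod<MA, MB>, MC>>;

// min: the smaller of 18000 and 32000.
static_assert(Best::cost() == 18000, "minimum cost");
// argmin: the expression that achieves it, A * (B * C).
static_assert(std::is_same<Best::Type, Prod<MA, Prod<MB, MC>>>::value, "optimal expression");
```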

4. Generate all Expressions

In the previous article, I used recursion to generate all expressions for a list of matrices, i.e. all possible trees whose leaves, visited in post-order, give the list of matrices.

Then we loop through the expressions and find the one with the minimum cost.

def build_all_expression_tree(vars):
    '''
        generate all trees for a list of mats
    '''
    if len(vars) == 1:
        return [VarNode(vars[0])]

    trees = []

    for i in range(1, len(vars)):
        left_trees = build_all_expression_tree(vars[:i])
        right_trees = build_all_expression_tree(vars[i:])

        for l in left_trees:
            for r in right_trees:
                root = OpNode('Prod')
                root.left = l
                root.right = r
                root.expression_string = \
                    'Prod<{},{}>'.format(root.left.expression_string,
                                     root.right.expression_string)
                trees.append(root)
    return trees

In C++, I don’t know how to implement the algorithm in compile-time.

So I hard-coded it.

I modified the Python implementation to generate C++ type expressions.

For example, the Python script called code_gen.py outputs the Min of all expressions for a fixed number of matrices.

template<typename A, typename B>
using Prod = ProdExpression<A, B>;

// case for, M1 * M2 * M3
template<typename M1, typename M2, typename M3>
using MinExpression3 = Min<Prod<M1, Prod<M2, M3>>, Prod<Prod<M1, M2>, M3>>;

// case for, M1 * M2 * M3 * M4
template<typename M1, typename M2, typename M3, typename M4>
using MinExpression4 = Min<Prod<M1, Prod<M2, Prod<M3, M4>>>,
    Min<Prod<M1, Prod<Prod<M2, M3>, M4>>,
        Min<Prod<Prod<M1, M2>, Prod<M3, M4>>,
            Min<Prod<Prod<M1, Prod<M2, M3>>, M4>,
                Prod<Prod<Prod<M1, M2>, M3>, M4>>>>>;

// case for, M1 * M2 * M3 * M4 * M5
template<typename M1, typename M2, typename M3, typename M4, typename M5>
using MinExpression5 = Min<Prod<M1, Prod<M2, Prod<M3, Prod<M4, M5>>>>,
    Min<Prod<M1, Prod<M2, Prod<Prod<M3, M4>, M5>>>,
        Min<Prod<M1, Prod<Prod<M2, M3>, Prod<M4, M5>>>,
            Min<Prod<M1, Prod<Prod<M2, Prod<M3, M4>>, M5>>,
                Min<Prod<M1, Prod<Prod<Prod<M2, M3>, M4>, M5>>,
                    Min<Prod<Prod<M1, M2>, Prod<M3, Prod<M4, M5>>>,
                        Min<Prod<Prod<M1, M2>, Prod<Prod<M3, M4>, M5>>,
                            Min<Prod<Prod<M1, Prod<M2, M3>>, Prod<M4, M5>>,
                                Min<Prod<Prod<Prod<M1, M2>, M3>, Prod<M4, M5>>,
                                    Min<Prod<Prod<M1, Prod<M2, Prod<M3, M4>>>, M5>,
                                        Min<Prod<Prod<M1, Prod<Prod<M2, M3>, M4>>, M5>,
                                            Min<Prod<Prod<Prod<M1, M2>, Prod<M3, M4>>, M5>,
                                                Min<Prod<Prod<Prod<M1, Prod<M2, M3>>, M4>, M5>,
                                                    Prod<Prod<Prod<Prod<M1, M2>, M3>, M4>, M5>>>>>>>>>>>>>>;
                                                    
template<typename M1, typename M2, typename M3, typename M4, typename M5, typename M6>
using MinExpression6 = Min<Prod<M1, Prod<M2, Prod<M3, Prod<M4, Prod<M5, M6>>>>>,
    Min<Prod<M1, Prod<M2, Prod<M3, Prod<Prod<M4, M5>, M6>>>>,
      ...........................
      ..... crazy stuff !! ......
      ...........................
        Prod<Prod<Prod<Prod<Prod<M1, M2>, M3>, M4>, M5>, M6>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>;
                                                    
...

Hard coding (by code generation) is actually quite common in C++ template libraries. So don’t be mad. Even if I could implement something to generate all expressions at compile time, it would be equivalent to hard-coding all expressions.

Remember that in the Python implementation, we used memoization (dynamic programming) to speed up the cost computation for all trees. In the C++ implementation, we get the memoization for free. For each type, the compiler should only need to compute its constexpr cost() function once. Because the return value of such a constexpr function is always the same, the compiler can simply save the result and reuse it whenever possible. (But I am not sure whether compilers actually do that.)
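As a cross-check on the generated Min chains, the minimum cost alone can also be computed with the classic dynamic program inside a constexpr function. This is a different technique from the type-based search used in this article, and it is only a sketch: it assumes C++14 or later, the name min_chain_cost is made up, and it yields the cost but not the optimal expression type.

```cpp
#include <cstddef>

// dims has K entries and describes K - 1 matrices:
// matrix i has shape dims[i] x dims[i + 1].
template <std::size_t K>
constexpr long min_chain_cost(const int (&dims)[K]) {
    static_assert(K >= 2, "need at least one matrix");
    constexpr std::size_t n = K - 1;  // number of matrices
    long best[n][n] = {};             // best[i][j]: min cost of multiplying matrices i..j
    for (std::size_t len = 2; len <= n; ++len) {
        for (std::size_t i = 0; i + len <= n; ++i) {
            const std::size_t j = i + len - 1;
            long lowest = -1;
            // Try every split point: (matrices i..k) * (matrices k+1..j).
            for (std::size_t k = i; k < j; ++k) {
                const long c = best[i][k] + best[k + 1][j]
                             + static_cast<long>(dims[i]) * dims[k + 1] * dims[j + 1];
                if (lowest < 0 || c < lowest) lowest = c;
            }
            best[i][j] = lowest;
        }
    }
    return best[0][n - 1];
}

// Sizes used in section 7: 19x9, 9x8, 8x7.
constexpr int sizes3[] = {19, 9, 8, 7};
static_assert(min_chain_cost(sizes3) == 1701, "matches the cost reported in section 7");

// Sizes used in the profiling run: 50x50, 50x20, 20x50, 50x6, 6x6.
constexpr int sizes5[] = {50, 50, 20, 50, 6, 6};
static_assert(min_chain_cost(sizes5) == 27720, "matches the profiling output");
```

The table best plays the role of the memoization cache in the Python version: each sub-chain is solved once and reused by every longer chain that contains it.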

Now, by using the generated expressions, we can compute the optimal expression to compute matrix multiplication in compile-time.

5. Evaluating an Expression at Run-time

Given an expression type, we want to evaluate the expression in run-time.

It can be done by function overloading and recursive calls.

// For product node
template <typename A, typename B>
typename ProdExpression<A, B>::MatrixType eval_expression(ProdExpression<A, B>, void* p[], int& idx)
{
    // 1. If type(A) is a ProdExpression, this product-node overload is called again.
    // 2. Otherwise (i.e. type(A) is a MatExpression), the leaf-node overload below is called.
    // The evaluation is a post-order traversal of the expression tree.
    typename A::MatrixType m1 = eval_expression(A(), p, idx);
    typename B::MatrixType m2 = eval_expression(B(), p, idx);
    return m1 * m2;
}

// For leaf node
template <typename M>
typename M::MatrixType eval_expression(M, void* p[], int& idx)
{
    // typename tells the compiler M::MatrixType is a type instead of a member variable.
    using EigenMatType = typename M::MatrixType;
    // Given type EigenMatType, we can cast the void* to EigenMatType*. 
    // Not safe. You have to trust me!
    return *reinterpret_cast<EigenMatType*>(p[idx++]);
}

The idea is a post-order tree traversal whose code is generated at compile time.

  1. For a ProdExpression, we evaluate its left and right children first. Then we return their product.
  2. We track the index into the post-order traversal array (which holds the user’s input matrices). Whenever we reach a leaf node (always a MatExpression), we evaluate it by returning the Eigen::Matrix at the current index of the array.

Note that we can replace the function overloading with template specialization. But I’d like to keep it simple.

To hide the idx, I added an entry function.

template <typename A, typename B>
typename ProdExpression<A, B>::MatrixType eval_expression(ProdExpression<A, B> E, void* p[])
{
    int idx = 0;
    return eval_expression(E, p, idx);
}

Note that the eval_expression function is generated recursively at compile time. In run time, we just need to feed data to it.
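The overload-and-recursion idea can be tried without Eigen. The sketch below makes several assumptions of its own: Mat is a tiny stand-in for Eigen::Matrix, Leaf and Prod are simplified expression types, and the leaf overload matches Leaf<R, C> directly instead of a generic M. It is meant to show the traversal, not to mirror the article’s code line by line.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Tiny fixed-size matrix, a stand-in for Eigen::Matrix in this sketch.
template <int R, int C>
struct Mat {
    std::array<double, R * C> v{};  // zero-initialized, row-major
    double& at(int r, int c) {
        assert(r >= 0 && r < R && c >= 0 && c < C);
        return v[static_cast<std::size_t>(r * C + c)];
    }
    double at(int r, int c) const {
        assert(r >= 0 && r < R && c >= 0 && c < C);
        return v[static_cast<std::size_t>(r * C + c)];
    }
};

template <int R, int K, int C>
Mat<R, C> operator*(const Mat<R, K>& a, const Mat<K, C>& b) {
    Mat<R, C> out;
    for (int r = 0; r < R; ++r)
        for (int c = 0; c < C; ++c)
            for (int k = 0; k < K; ++k)
                out.at(r, c) += a.at(r, k) * b.at(k, c);
    return out;
}

template <int R, int C>
struct Leaf {
    using MatrixType = Mat<R, C>;
    static constexpr int rows() { return R; }
    static constexpr int cols() { return C; }
};

template <typename A, typename B>
struct Prod {
    using MatrixType = Mat<A::rows(), B::cols()>;
    static constexpr int rows() { return A::rows(); }
    static constexpr int cols() { return B::cols(); }
};

// Leaf node: take the next matrix from the input array.
template <int R, int C>
Mat<R, C> eval_expression(Leaf<R, C>, void* p[], int& idx) {
    return *static_cast<Mat<R, C>*>(p[idx++]);
}

// Product node: evaluate the left subtree, then the right one, then multiply.
template <typename A, typename B>
typename Prod<A, B>::MatrixType eval_expression(Prod<A, B>, void* p[], int& idx) {
    typename A::MatrixType lhs = eval_expression(A(), p, idx);
    typename B::MatrixType rhs = eval_expression(B(), p, idx);
    return lhs * rhs;
}

// Entry point that hides idx.
template <typename E>
typename E::MatrixType eval_expression(E e, void* p[]) {
    int idx = 0;
    return eval_expression(e, p, idx);
}

// Usage: evaluate a * (b * c) for a 1x2, a 2x2 and a 2x1 matrix.
inline double demo_product() {
    Mat<1, 2> a; a.at(0, 0) = 1; a.at(0, 1) = 2;
    Mat<2, 2> b; b.at(0, 0) = 1; b.at(0, 1) = 2; b.at(1, 0) = 3; b.at(1, 1) = 4;
    Mat<2, 1> c; c.at(0, 0) = 1; c.at(1, 0) = 1;
    void* p[] = {&a, &b, &c};
    using E = Prod<Leaf<1, 2>, Prod<Leaf<2, 2>, Leaf<2, 1>>>;
    return eval_expression(E(), p).at(0, 0);
}
```

The pointers in p must be listed in the same left-to-right order as the leaves of the expression type, because each leaf simply consumes the next entry.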

6. Put Everything Together

Now we can put everything together by defining the product function which users call to find the optimal expression in compile-time and evaluate the expression in run-time.

For the product of 4 matrices, the product function is defined as,

template <typename EigenM1, typename EigenM2, typename EigenM3, typename EigenM4>
auto prod(EigenM1& m1, EigenM2& m2, EigenM3& m3, EigenM4& m4, bool verbose = false)
{
    auto me1 = MatExpression(m1);
    auto me2 = MatExpression(m2);
    auto me3 = MatExpression(m3);
    auto me4 = MatExpression(m4);

    using M1 = decltype(me1);
    using M2 = decltype(me2);
    using M3 = decltype(me3);
    using M4 = decltype(me4);
    using MinExpression = MinExpression4<M1, M2, M3, M4>;
    using OptimalExpression = typename MinExpression::Type;

    if (verbose) {
        print_cost_for_min_expression(MinExpression());
        std::cout << "OptimalExpression:" << std::endl;
        OptimalExpression::print_expression();
        std::cout << std::endl;
        std::cout << "optimal cost:" << OptimalExpression::cost() << std::endl;
    }

    void* p[4] = { &m1, &m2, &m3, &m4 };
    return eval_expression(OptimalExpression(), p);
}

The inputs are 4 Eigen::Matrix objects.

The aliases M1 to M4 (obtained with decltype) carry the sizes of the input matrices, deduced at compile time.
MinExpression is the hard-coded crazy stuff that computes the optimal expression at compile time. See section 4 for details.
OptimalExpression is an alias for the optimal expression type.
In the last two lines of the function, we evaluate the optimal expression at run time. See section 5 for details.
The return type is an Eigen::Matrix. Note that I let the compiler deduce the return type with auto (available since C++14) to save my hair. The MatExpression(m1) calls without template arguments rely on C++17 class template argument deduction.

I am not sure how to write a generic function for products with different numbers of inputs. So I ended up writing a product function for each number of matrices.

namespace matrix_optimal_product {
template <typename EigenM1, typename EigenM2, typename EigenM3>
auto prod(EigenM1& m1, EigenM2& m2, EigenM3& m3, bool verbose = false)
{ ... }

template <typename EigenM1, typename EigenM2, typename EigenM3, typename EigenM4>
auto prod(EigenM1& m1, EigenM2& m2, EigenM3& m3, EigenM4& m4, bool verbose = false)
{ ... }

template <typename EigenM1, typename EigenM2, typename EigenM3, typename EigenM4, typename EigenM5>
auto prod(EigenM1& m1, EigenM2& m2, EigenM3& m3, EigenM4& m4, EigenM5& m5, bool verbose = false)
{ ... }

template <typename EigenM1, typename EigenM2, typename EigenM3, typename EigenM4, typename EigenM5, typename EigenM6>
auto prod(EigenM1& m1, EigenM2& m2, EigenM3& m3, EigenM4& m4, EigenM5& m5, EigenM6& m6, bool verbose = false)
{ ... }

I guess it’s a common (ugly) practice for template.

7. Compile and Run

Firstly let’s make sure the computation for expressions happens in compile-time.

void testCompilationTime()
{
    std::cout << "============ testCompilationTime =============" << std::endl;
    auto m1 = Matrix<19, 9>();
    auto m2 = Matrix<9, 8>();
    auto m3 = Matrix<8, 7>();

    auto me1 = MatExpression(m1);
    auto me2 = MatExpression(m2);
    auto me3 = MatExpression(m3);
    using M1 = decltype(me1);
    using M2 = decltype(me2);
    using M3 = decltype(me3);
    using MinExpression = MinExpression3<M1, M2, M3>;

    print_cost_for_min_expression(MinExpression());

    using OptimalExpression = typename MinExpression::Type;

    std::cout << "OptimalExpression:" << std::endl;
    OptimalExpression::print_expression();
    std::cout << std::endl;
    std::cout << "optimal cost:" << OptimalExpression::cost() << std::endl;

    // make sure the cost is computed at compilation time.
    static_assert(OptimalExpression::cost() == 1701, "cost is not computed at compile time!");
    std::cout << "cost computed at compilation time" << std::endl;
}

The static_assert near the end of the function makes sure that OptimalExpression::cost() is computed at compile time. If the compiler can’t evaluate OptimalExpression::cost() at compile time, it reports an error.

The outputs,

============ testCompilationTime =============
Expression: ([mat,size:19,9] * ([mat,size:9,8] * [mat,size:8,7]))
cost:1701
Expression: 
(([mat,size:19,9] * [mat,size:9,8]) * [mat,size:8,7])
cost:2432
OptimalExpression:
([mat,size:19,9] * ([mat,size:9,8] * [mat,size:8,7]))
optimal cost:1701
cost computed at compilation time

Then, let’s run some tests to make sure the result is correct.

void testCorrectness()
{
    std::cout << "=========== testCorrectness ===========" << std::endl;

    auto m1 = Matrix<19, 9>();
    auto m2 = Matrix<9, 8>();
    auto m3 = Matrix<8, 7>();
    auto m4 = Matrix<7, 6>();
    auto m5 = Matrix<6, 6>();
    auto m6 = Matrix<6, 10>();

    // TODO: using Matrix<6, 6>::Random makes the template deduction failed.
    m1 = m1.Random();
    m2 = m2.Random();
    m3 = m3.Random();
    m4 = m4.Random();
    m5 = m5.Random();
    m6 = m6.Random();

    auto res3 = prod(m1, m2, m3);
    auto gt_res3 = m1 * m2 * m3;
    assert(std::abs((res3 - gt_res3).norm()) < EPSILON && "matrix mul 3 failed");
    std::cout << "matrix mul 3 passed" << std::endl;

    auto res4 = prod(m1, m2, m3, m4);
    auto gt_res4 = m1 * m2 * m3 * m4;
    assert(std::abs((res4 - gt_res4).norm()) < EPSILON && "matrix mul 4 failed");
    std::cout << "matrix mul 4 passed" << std::endl;

    auto res5 = prod(m1, m2, m3, m4, m5);
    auto gt_res5 = m1 * m2 * m3 * m4 * m5;
    assert(std::abs((res5 - gt_res5).norm()) < EPSILON && "matrix mul 5 failed");
    std::cout << "matrix mul 5 passed" << std::endl;

    auto res6 = prod(m1, m2, m3, m4, m5, m6);
    auto gt_res6 = m1 * m2 * m3 * m4 * m5 * m6;
    assert(std::abs((res6 - gt_res6).norm()) < EPSILON && "matrix mul 6 failed");
    std::cout << "matrix mul 6 passed" << std::endl;
}

I compared the results from the product function with direct Eigen matrix multiplication.

In the end, let’s profile the running time.

void profile()
{
    std::cout << "============ timing =============" << std::endl;

    auto m1 = Matrix<50, 50>();
    m1 = m1.Random();
    auto m2 = Matrix<50, 20>();
    m2 = m2.Random();
    auto m3 = Matrix<20, 50>();
    m3 = m3.Random();
    auto m4 = Matrix<50, 6>();
    m4 = m4.Random();
    auto m5 = Matrix<6, 6>();
    m5 = m5.Random();

    // This is not the best way to profile a function.
    constexpr int iters = 1e5;

    for (int i = 0; i < iters; ++i) {
        Timing t("opti-mul");
        Eigen::MatrixXd res = prod(m1, m2, m3, m4, m5);
    }

    for (int i = 0; i < iters; ++i) {
        Timing t("eigen-mul");
        Eigen::MatrixXd res = m1 * m2 * m3 * m4 * m5;
    }

    std::map<std::string, double> timing_info = Timing::print_info();

    std::cout << "optimal mul speed-up: " 
        << timing_info.at("eigen-mul") / timing_info.at("opti-mul") << std::endl;

    // Compute predicted speed-up by muls.
    ...
    using MinExpression = MinExpression5<M1, M2, M3, M4, M5>;
    using OptimalExpression = typename MinExpression::Type;
    // The optimal order
    std::cout << "# optimal muls:" << OptimalExpression::cost() << std::endl;
    // The default order: M1 * M2 * M3 * M4 * M5
    using DefaultExpression = Prod<Prod<Prod<Prod<M1, M2>, M3>, M4>, M5>;
    std::cout << "# default muls:" << DefaultExpression::cost() << std::endl;

    std::cout << "predicted speed-up by # muls: "
              << DefaultExpression::cost() / static_cast<double>(OptimalExpression::cost()) << std::endl;
}

In the first loop, I run the optimal matrix multiplication 1e5 times.
In the second loop, I run the default Eigen matrix multiplication 1e5 times.
I used a RAII profiler (the Timing class) to record the running time.
After the timing, I print the cost (number of muls) for the optimal expression.
Then I print the cost (number of muls) for the default expression.
Finally, I compute the predicted speed-up as the ratio of the two costs.
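For readers unfamiliar with RAII profilers, here is a hedged sketch of the idea. ScopedTimer is a hypothetical stand-in and not the Timing class used above, whose implementation is not shown in this article: the constructor records the start time, and the destructor adds the elapsed time to a table keyed by label.

```cpp
#include <chrono>
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-in for a RAII profiler; the article's Timing class may differ.
class ScopedTimer {
public:
    struct Stats {
        double total_ms = 0.0;
        long count = 0;
    };

    explicit ScopedTimer(std::string name)
        : name_(std::move(name)), start_(std::chrono::steady_clock::now()) {}

    // Runs when the timer goes out of scope and records the elapsed time.
    ~ScopedTimer() {
        const std::chrono::duration<double, std::milli> elapsed =
            std::chrono::steady_clock::now() - start_;
        Stats& s = stats()[name_];
        s.total_ms += elapsed.count();
        s.count += 1;
    }

    // One accumulator per label, shared by all timers.
    static std::map<std::string, Stats>& stats() {
        static std::map<std::string, Stats> table;
        return table;
    }

    // Mean time per measurement for a label; throws if the label is unknown.
    static double mean_ms(const std::string& name) {
        const Stats& s = stats().at(name);
        return s.total_ms / static_cast<double>(s.count);
    }

private:
    std::string name_;
    std::chrono::steady_clock::time_point start_;
};
```

Declaring a timer at the top of a loop body, as in the profile function, measures one iteration per object, because the object is destroyed at the end of every iteration.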

The outputs,

============ timing =============
eigen-mul  mean(ms): 0.0317066 stdev(ms):0.00600852
opti-mul  mean(ms): 0.0101451 stdev(ms):0.00214467
optimal mul speed-up: 3.12532
# optimal muls:27720
# default muls:116800
predicted speed-up by # muls: 4.21356

The predicted speed-up is 4.2 X. The actual speed-up is: 3.1 X.

Not bad.

8. Recommended Reading

If you are interested in C++ templates and have too much hair to waste, the slide is an excellent resource.

These slides explain the design of the Eigen library. They discuss lazy evaluation, template meta-programming, expression trees, vectorization, CRTP and so on. Highly recommended!

9. Summary

In this article, I implemented the optimal matrix multiplication algorithm in compile-time.

  1. I implemented the expression tree as a compile-time type.
  2. I used the expression tree to track the cost of evaluating an expression.
  3. To find the expression with the minimum cost, I implemented a compile-time Min function.
  4. To get all possible expressions, I used a python script to generate them.
  5. In the run-time, I can evaluate the optimal expression by the post-order tree traversal.

It was a fun project. I hope you enjoy it!

To be honest, I am not an expert on C++ templates and I can’t find good resources for them. Please let me know if you have something to recommend.