Optimal Matrix Chain Multiplication

1. What’s in this article?

The problem

Consider a matrix multiplication problem.

\(A B C \)

where \(A, B, C\) are matrices with \(A \in R^{(40, 30)}\), \(B \in R^{(30, 20)}\), and \(C \in R^{(20, 10)}\).

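For a single matrix multiplication \(X Y\) with \(X \in R^{(n, m)}\) and \(Y \in R^{(m, k)}\), the definition of matrix multiplication gives the loop below. Here res is an \(n \times k\) matrix initialized to zeros, and the inner index is called l so that it does not clash with the dimension k.

for i in range(n):
    for j in range(k):
        for l in range(m):
            res[i][j] += X[i][l] * Y[l][j]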

The total number of floating-point multiplications for \(XY\) is therefore \(nmk\): the innermost statement runs once for every combination of i, j and l.
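As a quick check of the \(nmk\) count, the sketch below runs the same triple loop in plain Python on lists of lists and counts the scalar multiplications it performs. naive_matmul is a hypothetical helper for illustration, not part of the implementation in this article.

```python
def naive_matmul(X, Y):
    """Multiply X (n x m) by Y (m x k), both given as lists of lists.

    Returns the product and the number of scalar multiplications performed.
    """
    n, m, k = len(X), len(Y), len(Y[0])
    assert all(len(row) == m for row in X), 'inner dimensions must match'
    res = [[0] * k for _ in range(n)]
    num_muls = 0
    for i in range(n):
        for j in range(k):
            for l in range(m):
                res[i][j] += X[i][l] * Y[l][j]
                num_muls += 1          # one multiplication per innermost step
    return res, num_muls


# A 2x3 matrix times a 3x4 matrix needs 2 * 3 * 4 = 24 multiplications.
X = [[1, 2, 3], [4, 5, 6]]
Y = [[1, 0, 0, 1], [0, 1, 0, 1], [0, 0, 1, 1]]
product, count = naive_matmul(X, Y)
print(product)  # [[1, 2, 3, 6], [4, 5, 6, 15]]
print(count)    # 24
```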

For \(ABC\) evaluated as \((A B) C\), the total number of floating-point multiplications can be broken into 2 parts:

\(M = A B\), #-muls = 40 * 30 * 20 = 24000

\(res = M C\), #-muls = 40 * 20 * 10 = 8000

The total number of multiplication is 32000.

However, if we evaluate the product in a different order, \(A (B C)\), the total number of multiplications is:

\(M = (B C)\), #-muls = 30 * 20 * 10 = 6000

\(res = A M\), #-muls = 40 * 30 * 10 = 12000

So the total number of multiplications to compute \(A (B C)\) is 18000.
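The arithmetic for the two orders can be reproduced in a few lines. mul_cost is a hypothetical helper that tracks only shapes, not matrix values:

```python
def mul_cost(x_shape, y_shape):
    """Number of multiplications and result shape of an (n, m) @ (m, k) product."""
    (n, m), (m2, k) = x_shape, y_shape
    assert m == m2, 'matrix dim mismatch'
    return n * m * k, (n, k)


A, B, C = (40, 30), (30, 20), (20, 10)

# (A B) C: form M = A B first, then M C.
cost_ab, M = mul_cost(A, B)
cost_mc, _ = mul_cost(M, C)
left_first = cost_ab + cost_mc

# A (B C): form M = B C first, then A M.
cost_bc, M = mul_cost(B, C)
cost_am, _ = mul_cost(A, M)
right_first = cost_bc + cost_am

print(left_first, right_first)  # 32000 18000
```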

Matrix chain multiplication is common in practice. For example, the Riccati equation contains several chained matrix products.

If we find the optimal matrix multiplication order, we can significantly boost the performance of the numerical computation.

The goal

We want to find the optimal matrix multiplication order such that the total number of floating-point multiplications is minimized.

The outline

I will explain and implement an algorithm to find the optimal matrix multiplication order. In the sections below, I will discuss

  1. How to represent an expression (e.g. \(A * (B * C)\)) as an expression tree.
  2. How to compute the number of multiplications for a given expression tree.
  3. A simple way to find the optimal expression tree: iterate over all expression trees and compute the number of multiplications of each.
  4. How dynamic programming (or memoization, if you like) simplifies this brute-force algorithm.

The implementation link.

2. The Expression Tree

For the matrix multiplication problem, an expression is simply a product of matrices. For example, the expression below computes the product of 4 matrices.

\(A * B * C * D\)

An expression can be represented as a tree, which we call the expression tree.

One possible expression tree for \(A * B * C * D\) is:

    _@ 
   /  \
  _@  D
 /  \  
 @  C  
/ \    
A B 

The expression tree explicitly defines the order of the multiplications. For example, the tree above computes \(((A * B) * C) * D\).

Evaluating an expression tree is simply a post-order traversal of the tree: both children of a node are evaluated before the node itself. For the tree above, we first evaluate the bottom product node (\(A * B\)), then the middle one (\((A * B) * C\)), and finally the root.
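To make the post-order evaluation concrete, here is a minimal sketch that records the order in which product nodes are evaluated. It uses nested tuples instead of the Node classes defined in the next section, a simplification for illustration only; evaluation_order is a hypothetical name.

```python
def evaluation_order(tree):
    """Return the products of `tree` in the order a post-order traversal
    evaluates them.

    A leaf is a matrix name (str); an internal node is a (left, right) tuple.
    """
    steps = []

    def visit(node):
        if isinstance(node, str):     # leaf: a matrix, nothing to compute
            return node
        left = visit(node[0])         # evaluate the left subtree first
        right = visit(node[1])        # then the right subtree
        expr = '({} * {})'.format(left, right)
        steps.append(expr)            # finally the node itself
        return expr

    visit(tree)
    return steps


tree = ((('A', 'B'), 'C'), 'D')       # the tree drawn above
for step in evaluation_order(tree):
    print(step)
# (A * B)
# ((A * B) * C)
# (((A * B) * C) * D)
```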

2.1 Tree Implementation

The expression tree for the matrix multiplication problem can be represented as a binary tree.

Since we only consider matrix multiplication, leaf nodes of the tree always represent matrices and non-leaf nodes always represent multiplication operations.

In Python, the tree can be implemented as follows:

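class Variable:
    # Minimal container assumed by the code in this article (it is not shown
    # in the original): a matrix name plus its NumPy data.
    def __init__(self, name, data):
        self.name = name
        self.data = data


class Node:
    def __init__(self, node_type):
        self.node_type = node_type      # 'OpNode' or 'VarNode'
        self.shape = None               # shape of the matrix this node produces
        self.expression_string = None   # e.g. '(A@B)', used for printing


class OpNode(Node):
    # Non-leaf node: a binary operation (here always '@') on two sub-trees.
    def __init__(self, op):
        super().__init__('OpNode')
        self.left = None
        self.right = None
        self.op = op


class VarNode(Node):
    # Leaf node: wraps a Variable; its shape is known from the data.
    def __init__(self, var):
        super().__init__('VarNode')
        self.var = var
        self.shape = var.data.shape
        self.expression_string = var.name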
        

VarNode represents leaf nodes. OpNode represents non-leaf nodes. Node is a base class for them.

With the tree defined, we can print the tree like this,

def print_tree(root):
    '''
        https://stackoverflow.com/questions/34012886/print-binary-tree-level-by-level-in-python
    '''
    def _display_aux(root):
        # No child.
        if root.node_type == 'VarNode':
            line = root.var.name
            width = len(line)
            height = 1
            middle = width // 2
            return [line], width, height, middle

        # Two children.
        left, n, p, x = _display_aux(root.left)
        right, m, q, y = _display_aux(root.right)
        s = root.op
        u = len(s)
        first_line = (x + 1) * ' ' + (n - x - 1) * \
            '_' + s + y * '_' + (m - y) * ' '
        second_line = x * ' ' + '/' + \
            (n - x - 1 + u + y) * ' ' + '\\' + (m - y - 1) * ' '
        if p < q:
            left += [n * ' '] * (q - p)
        elif q < p:
            right += [m * ' '] * (p - q)
        zipped_lines = zip(left, right)
        lines = [first_line, second_line] + \
            [a + u * ' ' + b for a, b in zipped_lines]
        return lines, n + m + u, max(p, q) + 2, n + u // 2
    
    lines, _, _, _ = _display_aux(root)
    for line in lines:
        print(line)

It’s basically a post-order traversal (or divide-and-conquer): we compute the output lines for the two children and then merge them.

2.2 Build an Expression Tree

Given a matrix multiplication expression, how can we convert it to an expression tree?

That is, given \(A @ B @ C @ D @ E @ F\) (@ is the matrix multiplication operator in Python), how can we build the expression tree below?

 @_        
/  \       
A  @_      
  /  \     
  B  @_    
    /  \   
    C  @_  
      /  \ 
      D  @ 
        / \
        E F

This is the same problem that the lexical analysis and syntax analysis phases of a compiler solve.

Lexical analysis

We first need to break the input string into tokens, the meaningful symbols of the input.

For example, for the input string "A @ B @ C @ D @ E @ F", the lexical analyzer breaks the string into the list of tokens ["A", "@", "B", "@", "C", "@", "D", "@", "E", "@", "F"].

Lexical analysis can be done by regular expression matching. I will assume lexical analysis has already been done.
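A minimal regex-based lexer could look like the sketch below. The token set (identifiers, @ and parentheses), the name tokenize and the error handling are my assumptions, since the article skips this step:

```python
import re

# Assumed token set: matrix names (identifiers), the @ operator and
# parentheses. Whitespace is skipped because findall returns only matches.
TOKEN_RE = re.compile(r'[A-Za-z_]\w*|[@()]')


def tokenize(text):
    tokens = TOKEN_RE.findall(text)
    # Anything left after removing all tokens must be whitespace.
    leftover = TOKEN_RE.sub('', text).strip()
    if leftover:
        raise ValueError('unexpected characters: {!r}'.format(leftover))
    return tokens


print(tokenize('A @ B @ C @ D @ E @ F'))
# ['A', '@', 'B', '@', 'C', '@', 'D', '@', 'E', '@', 'F']
```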

Syntax analysis

Given a list of tokens, the syntax analyzer (parser) uses a predefined grammar to parse the tokens into an expression tree.
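To illustrate what a predefined grammar means, here is a small recursive-descent parser sketch. The grammar (matrix names, @ and parentheses) is my assumption, and the parser produces nested tuples rather than the Node classes above; it makes @ left-associative, as Python does.

```python
def parse(tokens):
    """Parse a token list into a nested-tuple expression tree.

    Assumed grammar (a sketch, not the article's code):
        expr := term ('@' term)*      (left-associative)
        term := NAME | '(' expr ')'
    A leaf is a name (str); a product is a (left, right) tuple.
    """
    pos = 0

    def peek():
        return tokens[pos] if pos < len(tokens) else None

    def expect(tok):
        nonlocal pos
        if peek() != tok:
            raise SyntaxError('expected {!r}, got {!r}'.format(tok, peek()))
        pos += 1

    def term():
        nonlocal pos
        tok = peek()
        if tok == '(':
            expect('(')
            node = expr()
            expect(')')
            return node
        if tok is None or tok in ('@', ')'):
            raise SyntaxError('expected a matrix name, got {!r}'.format(tok))
        pos += 1
        return tok

    def expr():
        node = term()
        while peek() == '@':
            expect('@')
            node = (node, term())     # fold each new term into the left
        return node

    tree = expr()
    if pos != len(tokens):
        raise SyntaxError('unexpected token {!r}'.format(peek()))
    return tree


print(parse(['A', '@', 'B', '@', 'C']))            # (('A', 'B'), 'C')
print(parse(['A', '@', '(', 'B', '@', 'C', ')']))  # ('A', ('B', 'C'))
```

Because expr folds each new term into the left operand, A @ B @ C groups as (A @ B) @ C, while parentheses override that grouping.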

Implementation

A simple implementation that parses ["A", "@", "B", "@", "C", "@", "D", "@", "E", "@", "F"] into a tree is:

def build_expression_tree_simple(vars):
    '''
        build a single right-leaning tree for a list of mats
        (the '@' tokens are dropped because '@' is the only operation)
    '''
    if len(vars) == 1:
        return VarNode(vars[0])

    op_node = OpNode('@')
    op_node.left = VarNode(vars[0])                          # first matrix
    op_node.right = build_expression_tree_simple(vars[1:])  # the rest, recursively
    op_node.expression_string = \
        '({}@{})'.format(op_node.left.expression_string,
                         op_node.right.expression_string)
    return op_node

Using this implementation, the expression tree for ["A", "@", "B", "@", "C", "@", "D", "@", "E", "@", "F"] is:

 @_        
/  \       
A  @_      
  /  \     
  B  @_    
    /  \   
    C  @_  
      /  \ 
      D  @ 
        / \
        E F

In the implementation, I ignored @ because it’s the only operation.

The idea is to recursively build the expression tree.

If the input has only one matrix, we just need to return a VarNode to represent the matrix.

If the input has more than one matrix, we construct an OpNode to represent a multiplication. Its left child is the first input matrix. Its right child is the expression tree for the remaining matrices, which we build by calling the function recursively.

Precedence and Associativity

You may have noticed that there are many ways to build the expression tree. For example, instead of the implementation above, we can build the left child from the first half of the matrices and the right child from the second half.

def build_expression_tree_mid(vars):
    '''
        build a single tree for a list of mats by splitting them in the middle
    '''
    if len(vars) == 1:
        return VarNode(vars[0])
    op_node = OpNode('@')
    var_len = len(vars)
    # each half is built with the simple (right-leaning) builder above
    op_node.left = build_expression_tree_simple(vars[: var_len // 2])
    op_node.right = build_expression_tree_simple(vars[var_len // 2 :])
    op_node.expression_string = \
        '({}@{})'.format(op_node.left.expression_string,
                         op_node.right.expression_string)
    return op_node

The tree will be,

  ___@_    
 /     \   
 @_    @_  
/  \  /  \ 
A  @  D  @ 
  / \   / \
  B C   E F

A parser resolves this ambiguity by defining the precedence and the associativity of each operator. When building the tree, the parser grows it according to these rules.
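Python itself resolves the ambiguity this way: @ has the same precedence as * and is left-associative. The standard ast module shows how A @ B @ C is grouped, without evaluating anything:

```python
import ast

# Parse (but do not evaluate) the expression with Python's own parser.
tree = ast.parse('A @ B @ C', mode='eval').body

# @ is left-associative, so the root's left child is the product A @ B
# and its right child is the name C, i.e. (A @ B) @ C.
print(type(tree.op).__name__)                  # MatMult
print(type(tree.left).__name__)                # BinOp
print(tree.right.id)                           # C
print(tree.left.left.id, tree.left.right.id)   # A B
```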

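Luckily, for the optimal matrix multiplication problem, we don’t need to worry about precedence and associativity: matrix multiplication is associative, so every expression tree gives the same product (up to floating-point rounding), and we are free to choose the cheapest one.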

3. Compute the Number of Multiplications

Given an expression tree for matrix multiplication, we can compute its total number of floating-point multiplications with a post-order traversal.

def count_num_muls(root):
    if root.node_type == 'VarNode':
        return 0

    assert root.node_type == 'OpNode'
    left_muls = count_num_muls(root.left)
    right_muls = count_num_muls(root.right)

    lr, lc = root.left.shape
    rr, rc = root.right.shape
    assert lc == rr, 'matrix dim mismatch'
    root.shape = (lr, rc)
    cur_muls = lc * lr * rc

    total_muls = left_muls + right_muls + cur_muls

    return total_muls

First, we compute the number of muls needed for the left and the right children. Then we add the number of muls needed to multiply the two results: an \((lr, lc)\) matrix times an \((rr, rc)\) matrix with \(lc = rr\) costs \(lr \cdot lc \cdot rc\) muls. Finally, we return the total.

Note that we need to set the shape of each OpNode as we go, because the result of an OpNode is a matrix that can be the input to another OpNode.
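As a cross-check of count_num_muls, the same post-order computation on a stripped-down representation (nested tuples plus a shape dictionary, both my own simplification) reproduces the counts reported in section 4.2:

```python
def count_muls(tree, shapes):
    """Return (#-muls, result shape) for a nested-tuple expression tree.

    A leaf is a matrix name looked up in `shapes`; a product is a
    (left, right) tuple. Like count_num_muls, it works post-order and
    passes the result shape up to the parent.
    """
    if isinstance(tree, str):
        return 0, shapes[tree]
    left_muls, (lr, lc) = count_muls(tree[0], shapes)
    right_muls, (rr, rc) = count_muls(tree[1], shapes)
    assert lc == rr, 'matrix dim mismatch'
    return left_muls + right_muls + lr * lc * rc, (lr, rc)


# Shapes from gen_test_data in section 4.2.
shapes = {'A': (20, 20), 'B': (20, 35), 'C': (35, 100),
          'D': (100, 360), 'E': (360, 10), 'F': (10, 10)}
left_to_right = ((((('A', 'B'), 'C'), 'D'), 'E'), 'F')
best = ('A', (('B', ('C', ('D', 'E'))), 'F'))
print(count_muls(left_to_right, shapes)[0])  # 878000
print(count_muls(best, shapes)[0])           # 408000
```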

4. Find the Optimal Expression Tree

Recall the goal of the article is to find the best order of matrix multiplication such that the total number of floating-point multiplications is minimized.

A naive method is to generate all possible expression trees for a list of tokens and compute the number of multiplications for each tree.

4.1 Generate all Expression Trees

Given a list of tokens ["A", "B", "C", "D", "E", "F"], how can we generate all the expression trees?

This can be done by recursion.

def build_all_expression_tree(vars):
    '''
        build all trees for a list of mats
    '''
    if len(vars) == 1:
        return [VarNode(vars[0])]

    trees = []

    for i in range(1, len(vars)):
        left_trees = build_all_expression_tree(vars[:i])
        right_trees = build_all_expression_tree(vars[i:])

        for l in left_trees:
            for r in right_trees:
                root = OpNode('@')
                root.left = l
                root.right = r
                root.expression_string = \
                    '({}@{})'.format(root.left.expression_string,
                                     root.right.expression_string)
                trees.append(root)
    return trees

The function outputs all possible expression trees for a list of tokens.

In the function, we iterate over all possible split points of the tokens. For each split, we call the function recursively to generate all possible sub-trees for the left part and the right part. Then every pairing of a left sub-tree with a right sub-tree is a valid expression tree.
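The same recursion also tells us how many trees there are: for n matrices the count is the Catalan number \(C_{n-1}\), which grows exponentially. This sketch (num_trees is a hypothetical helper) counts the trees with the same split-and-pair recursion, without building them:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def num_trees(n):
    """Number of expression trees (full parenthesizations) of n matrices.

    Same recursion as build_all_expression_tree: split after position i,
    then pair every left sub-tree with every right sub-tree.
    """
    if n == 1:
        return 1
    return sum(num_trees(i) * num_trees(n - i) for i in range(1, n))


print([num_trees(n) for n in range(1, 9)])  # [1, 1, 2, 5, 14, 42, 132, 429]
```

For the six matrices in section 4.2 that is 42 trees; for 20 matrices it is already about 1.8 billion.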

4.2 Find the Best Tree

To find the best tree, we generate all trees and apply count_num_muls from section 3 to each of them. Finally, we select the tree with the fewest muls.

def test_all_tree():
    vars = gen_test_data()

    trees = build_all_expression_tree(vars)
    muls = [count_num_muls(t) for t in trees]
    for t, m in zip(trees, muls):
        print('for tree:', t.expression_string)
        print_tree(t)
        print('# muls:', m)

    # pick the tree with the fewest multiplications
    min_tree, min_muls = min(zip(trees, muls), key=lambda x: x[1])
    print('optimal expression:', min_tree.expression_string)
    print('optimal # muls:', min_muls)

For the following test data,

import numpy as np


def gen_test_data():
    # six matrices whose shapes chain together (20x20, 20x35, ..., 10x10)
    A = Variable('A', np.ones([20, 20]))
    B = Variable('B', np.ones([20, 35]))
    C = Variable('C', np.ones([35, 100]))
    D = Variable('D', np.ones([100, 360]))
    E = Variable('E', np.ones([360, 10]))
    F = Variable('F', np.ones([10, 10]))
    vars = [A, B, C, D, E, F]
    return vars

part of the output is:

....

for tree: (((A@(B@(C@D)))@E)@F)
        _@ 
       /  \
  _____@  F
 /      \  
 @_     E  
/  \       
A  @_      
  /  \     
  B  @     
    / \    
    C D    
# muls: 1730000

for tree: (((A@((B@C)@D))@E)@F)
        _@ 
       /  \
  _____@  F
 /      \  
 @___   E  
/    \     
A   _@     
   /  \    
   @  D    
  / \      
  B C      
# muls: 1008000

for tree: ((((A@B)@(C@D))@E)@F)
        _@ 
       /  \
    ___@  F
   /    \  
  _@_   E  
 /   \     
 @   @     
/ \ / \    
A B C D    
# muls: 1600000

for tree: ((((A@(B@C))@D)@E)@F)
        _@ 
       /  \
      _@  F
     /  \  
  ___@  E  
 /    \    
 @_   D    
/  \       
A  @       
  / \      
  B C      
# muls: 904000

for tree: (((((A@B)@C)@D)@E)@F)
        _@ 
       /  \
      _@  F
     /  \  
    _@  E  
   /  \    
  _@  D    
 /  \      
 @  C      
/ \        
A B        
# muls: 878000

optimal expression: (A@((B@(C@(D@E)))@F))
optimal # muls: 408000

5. Dynamic Programming by Memoization

We can improve the above algorithm by memoizing computed results. For example, once the number of multiplications for A@(B@C) has been computed, we save it in a dictionary keyed by the expression string. When we see the same sub-expression again (in another tree), we just read the stored result.

The implementation can be:

def count_num_muls(trees):
    '''
        memoized version: a generator that yields the number of muls for
        each tree, sharing one memo table across all trees
    '''
    dp_map = {}   # expression_string -> (#muls, shape)

    def count_num_muls_internal(root):
        s = root.expression_string
        if s in dp_map:
            count, shape = dp_map[s]
            # update the shape of mul
            root.shape = shape
            return count

        if root.node_type == 'VarNode':
            return 0

        assert root.node_type == 'OpNode'
        left_muls = count_num_muls_internal(root.left)
        right_muls = count_num_muls_internal(root.right)

        lr, lc = root.left.shape
        rr, rc = root.right.shape
        assert lc == rr, 'matrix dim mismatch'
        root.shape = (lr, rc)
        cur_muls = lc * lr * rc

        total_muls = left_muls + right_muls + cur_muls
        # need to track the shape of mul
        dp_map[s] = (total_muls, root.shape)
        return total_muls

    for t in trees:
        yield count_num_muls_internal(t)
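Memoizing on the expression string speeds up the counting, but the code above still enumerates every tree, so its run time still grows exponentially with the number of matrices. The textbook matrix-chain dynamic program (MATRIX-CHAIN-ORDER in CLRS, for example) memoizes on sub-chains instead: the best cost for matrices i..j depends only on the best costs of shorter sub-chains. Below is a top-down sketch of that standard technique, not the approach implemented in this article; optimal_order is a hypothetical name, and dims lists the shared dimensions so that matrix i has shape (dims[i], dims[i + 1]).

```python
from functools import lru_cache


def optimal_order(dims, names):
    """Minimum #-muls and a best parenthesization for a matrix chain.

    Matrix i is names[i] with shape (dims[i], dims[i + 1]), so
    len(dims) == len(names) + 1. Ties keep the leftmost split point.
    """
    @lru_cache(maxsize=None)
    def best(i, j):
        # (min #-muls, expression string) for the sub-chain names[i..j]
        if i == j:
            return 0, names[i]
        candidates = []
        for k in range(i, j):                # split into i..k and k+1..j
            left_muls, left_expr = best(i, k)
            right_muls, right_expr = best(k + 1, j)
            # (dims[i], dims[k+1]) @ (dims[k+1], dims[j+1])
            muls = left_muls + right_muls + dims[i] * dims[k + 1] * dims[j + 1]
            candidates.append((muls, '({}@{})'.format(left_expr, right_expr)))
        return min(candidates, key=lambda c: c[0])

    assert len(dims) == len(names) + 1, 'need one more dim than matrices'
    return best(0, len(names) - 1)


print(optimal_order([20, 20, 35, 100, 360, 10, 10], list('ABCDEF')))
# (408000, '(A@((B@(C@(D@E)))@F))')
```

There are \(O(n^2)\) sub-chains and each tries \(O(n)\) split points, so this takes \(O(n^3)\) time (ignoring the string building) instead of visiting a Catalan number of trees. On the test data it finds the same optimum as the brute-force search.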

6. Timing

Consider the NumPy expression A @ B @ C @ D @ E @ F, with the matrices from gen_test_data in section 4.2: A is (20, 20), B is (20, 35), C is (35, 100), D is (100, 360), E is (360, 10) and F is (10, 10).

The optimal matrix multiplication order (expression tree) is:

optimal expression: (A@((B@(C@(D@E)))@F))
 @_______  
/        \ 
A   _____@ 
   /      \
   @_     F
  /  \     
  B  @_    
    /  \   
    C  @   
      / \  
      D E  
optimal # muls: 408000

For comparison, the default order (Python’s @ is left-associative, so A@B@C@D@E@F is evaluated left to right) gives:

for tree: (((((A@B)@C)@D)@E)@F)
        _@ 
       /  \
      _@  F
     /  \  
    _@  E  
   /  \    
  _@  D    
 /  \      
 @  C      
/ \        
A B        
# muls: 878000

I used Python’s timeit module to time each expression.

import timeit


def time_expr(number=10000):
    print('timing....')
    # runs once before each repetition: build the NumPy arrays A..F
    setup_code = '''
from __main__ import gen_test_data
A, B, C, D, E, F = [v.data for v in gen_test_data()]
'''
    test_codes = ['val = A@B@C@D@E@F',            # default (left-to-right) order
                  'val = (A@((B@(C@(D@E)))@F))']  # optimal order
    for stmt in test_codes:
        # timeit.repeat returns, for each repetition, the total time in
        # seconds for `number` executions of stmt; keep the fastest one
        times = timeit.repeat(setup=setup_code, stmt=stmt, number=number)
        print('time for {}: {} s for {} runs'.format(stmt, min(times), number))

The timing results are:

timing....
time for val = A@B@C@D@E@F: 1.5322212360042613 s for 10000 runs
time for val = (A@((B@(C@(D@E)))@F)): 0.4744670180079993 s for 10000 runs

That is roughly 0.15 ms versus 0.047 ms per evaluation, so (A@((B@(C@(D@E)))@F)) is about 3 times faster than A@B@C@D@E@F.

(By the operation counts, 878000 versus 408000, the speed-up should only be about 2 times. Run time on a modern computer is hard to reason about: the number of multiplications is not the only factor.)

7. My Real Goal and My Failure

The Python implementation finds the optimal expression at run time. To use it, we would have to compute the optimal expressions at run time and then replace the expressions in the source code by hand.

In practice, the sizes of matrices are often known at compile time.

My goal was to use C++ template metaprogramming to compute the optimal matrix multiplication order at compile time.

C++ template metaprogramming is like writing code for the compiler to run: the computation is done before the program runs.

For example, the greatest common divisor can be computed by a constexpr function:

constexpr auto gcd(int a, int b){
  while (b != 0){
    auto t= b;
    b= a % b;
    a= t;
  }
  return a;
}

int main(){
  constexpr int i= gcd(11,121);
}

Because the arguments of gcd are compile-time constants and i is declared constexpr, the result is computed at compile time. The keyword constexpr does the trick. At run time, the code is equivalent to:

int main(){
  constexpr int i = 11;
}

For the matrix multiplication problem, the sizes of matrices can be compile-time constants (e.g. fixed-size matrices in Eigen). So ideally we should be able to find the optimal matrix multiplication order at compile time.

My plan was to do some crazy template stuff. However, I couldn’t make it work: creating a constexpr data structure seems tricky. Let me know if you have any ideas.

Update: I solved the problem at compile time in this article.

8. The End

Interestingly, I can’t find a C++ library that implements the above algorithm at compile time. Eigen does analyze expressions (it is built on expression templates), but as far as I can tell it does not optimize the matrix multiplication order. Please let me know if there is such a library.

If you want to learn more about compilers, please take this online course: https://www.edx.org/course/compilers.

For more information about C++ template meta-programming, this video gave a good overview.
