1. What’s in this article?
The problem
Consider a matrix multiplication problem.
\(A B C \)
Where \(A, B, C\) are matrices. \(A \in R^{(40, 30)}\), \(B \in R^{(30, 20)}\), \(C \in R^{(20, 10)}\).
For a single matrix multiplication \(X Y\) where \(X \in R^{(n, m)}\), \(Y \in R^{(m, k)}\), by the definition of matrix multiplication,
for i = 1 to n
    for j = 1 to k
        for l = 1 to m
            res[i][j] = res[i][j] + X[i][l] * Y[l][j]
The total number of floating-point multiplications for \(XY\) is \(nmk\).
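To make the count concrete, here is a minimal pure-Python sketch of the loop above that also counts scalar multiplications. The function name `matmul_count` is made up for this illustration, and matrices are assumed to be plain lists of lists.

```python
def matmul_count(X, Y):
    """Multiply X (n x m) by Y (m x k); return (result, number of scalar multiplications)."""
    n, m, k = len(X), len(Y), len(Y[0])
    assert all(len(row) == m for row in X), 'matrix dim mismatch'
    res = [[0.0] * k for _ in range(n)]
    muls = 0
    for i in range(n):
        for j in range(k):
            for l in range(m):
                res[i][j] += X[i][l] * Y[l][j]
                muls += 1          # one scalar multiplication per inner step
    return res, muls


X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]       # 2 x 3
Y = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]     # 3 x 2
res, muls = matmul_count(X, Y)
print(res, muls)                             # muls is n * m * k = 2 * 3 * 2
```

The counter is incremented once per inner-loop step, so it always ends up equal to n * m * k.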
For \(ABC\), the total number of floating point multiplications can be broken into 2 parts,
\(M = A B\), #-muls = 40 * 30 * 20 = 24000
\(res = M C\), #-muls = 40 * 20 * 10 = 8000
The total number of multiplications is 32000.
However, if we evaluate the matrix multiplication in a different order, \(A (B C)\), the total number of multiplications is,
\(M = (B C)\), #-muls = 30 * 20 * 10 = 6000
\(res = A M\), #-muls = 40 * 30 * 10 = 12000
So the total number of multiplications to compute \(A (B C)\) is 18000.
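The arithmetic above can be checked with a few lines of Python. This is only a sketch: `mul_cost` is a hypothetical helper that works on shapes alone, not on real matrices.

```python
def mul_cost(left_shape, right_shape):
    """Return (result shape, #-muls) for multiplying two matrices with the given shapes."""
    (n, m), (m2, k) = left_shape, right_shape
    assert m == m2, 'matrix dim mismatch'
    return (n, k), n * m * k


A, B, C = (40, 30), (30, 20), (20, 10)

# (A B) C
ab_shape, cost_ab = mul_cost(A, B)
_, cost_ab_c = mul_cost(ab_shape, C)
left_first = cost_ab + cost_ab_c

# A (B C)
bc_shape, cost_bc = mul_cost(B, C)
_, cost_a_bc = mul_cost(A, bc_shape)
right_first = cost_bc + cost_a_bc

print(left_first, right_first)   # 32000 18000
```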
Matrix chain multiplication is common in practice. For example, it appears in the Riccati equation.
If we find the optimal matrix multiplication order, we can significantly boost the performance of the numerical computation.
The goal
We want to find the matrix multiplication order that minimizes the total number of floating-point multiplications.
The outline
I will explain and implement the algorithm to find the optimal matrix multiplication order. In the sections below, I will discuss:
- How to represent an expression (e.g. \(A * (B * C)\)) as an expression tree.
- How to compute the number of multiplications for an expression, given its expression tree.
- A simple way to find the optimal expression tree: iterate over all expression trees and compute the number of multiplications for each.
- How the brute-force algorithm above can be improved with dynamic programming (or memoization, if you like).
The implementation link.
2. The Expression Tree
For the matrix multiplication problem, an expression is simply a product of matrices. For example, the expression below computes the product of 4 matrices.
\(A * B * C * D\)
An expression can be represented as a tree. We call it the expression tree.
The expression tree for expression \(A * B * C * D\) can be,
    _@
   /  \
  _@  D
 /  \
 @  C
/ \
A B
The expression tree explicitly defines the order of the multiplications. For example, in the expression tree above, the order is \(((A * B) * C) * D\).
Evaluating an expression tree is simply a post-order traversal of the tree. For the tree above, we first evaluate the product node at line 5, then the product node at line 3, and finally the product node at line 1.
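As a small illustration of the post-order rule, the sketch below lists the products of a tree in the order they would be evaluated. It assumes a simplified representation made up for this example, where a leaf is a string and a product is a `(left, right)` tuple; the class-based tree used in the rest of the article comes next.

```python
def evaluation_order(tree):
    """Return the products of a nested-tuple expression tree in post-order.

    A leaf is a string (a matrix name); a product is a (left, right) tuple.
    """
    order = []

    def visit(node):
        if isinstance(node, str):        # leaf: nothing to compute
            return node
        left = visit(node[0])            # evaluate the left child first
        right = visit(node[1])           # then the right child
        expr = '({}*{})'.format(left, right)
        order.append(expr)               # finally the product node itself
        return expr

    visit(tree)
    return order


tree = ((('A', 'B'), 'C'), 'D')          # ((A * B) * C) * D
print(evaluation_order(tree))
# ['(A*B)', '((A*B)*C)', '(((A*B)*C)*D)']
```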
2.1 Tree Implementation
The expression tree for the matrix multiplication problem can be represented as a binary tree.
Since we only consider matrix multiplication, leaf nodes for the tree always represent matrices and non-leaf nodes always represent multiplication operations.
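In Python the tree can be implemented as,
class Node:
    '''Base class for all expression tree nodes.'''
    def __init__(self, type):
        self.node_type = type            # 'OpNode' or 'VarNode'
        self.shape = None                # shape of the matrix this node evaluates to
        self.expression_string = None    # e.g. '(A@B)'


class OpNode(Node):
    '''Non-leaf node: a multiplication of its two children.'''
    def __init__(self, op):
        super().__init__('OpNode')
        self.left = None
        self.right = None
        self.op = op


class VarNode(Node):
    '''Leaf node: a single matrix.'''
    def __init__(self, var):
        super().__init__('VarNode')
        self.var = var
        self.shape = var.data.shape
        self.expression_string = var.name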
VarNode represents leaf nodes. OpNode represents non-leaf nodes. Node is a base class for them.
With the tree defined, we can print the tree like this,
def print_tree(root):
    '''
    https://stackoverflow.com/questions/34012886/print-binary-tree-level-by-level-in-python
    '''
    def _display_aux(root):
        # No child.
        if root.node_type == 'VarNode':
            line = root.var.name
            width = len(line)
            height = 1
            middle = width // 2
            return [line], width, height, middle

        # Two children.
        left, n, p, x = _display_aux(root.left)
        right, m, q, y = _display_aux(root.right)
        s = root.op
        u = len(s)
        first_line = (x + 1) * ' ' + (n - x - 1) * \
            '_' + s + y * '_' + (m - y) * ' '
        second_line = x * ' ' + '/' + \
            (n - x - 1 + u + y) * ' ' + '\\' + (m - y - 1) * ' '
        if p < q:
            left += [n * ' '] * (q - p)
        elif q < p:
            right += [m * ' '] * (p - q)
        zipped_lines = zip(left, right)
        lines = [first_line, second_line] + \
            [a + u * ' ' + b for a, b in zipped_lines]
        return lines, n + m + u, max(p, q) + 2, n + u // 2

    lines, _, _, _ = _display_aux(root)
    for line in lines:
        print(line)
It’s basically a post-order traversal (or divide-and-conquer). We compute the output strings for the two children and merge them together.
2.2 Build an expression tree
Given a matrix multiplication expression, how can we convert it to an expression tree?
That is, given \(A @ B @ C @ D @ E @ F\) (@ is the matrix multiplication operator in Python), how can we build the expression tree below?
 @_
/  \
A  @_
  /  \
  B  @_
    /  \
    C  @_
      /  \
      D  @
        / \
        E F
This problem is the same as the lexical analysis and syntax analysis problems in compilers.
Lexical analysis
We first need to break the input string into tokens. Tokens are the meaningful symbols of the input.
For example, if the input string is “A @ B @ C @ D @ E @ F”, the lexical analyzer breaks the string into the list of tokens [“A”, “@”, “B”, “@”, “C”, “@”, “D”, “@”, “E”, “@”, “F”].
The Lexical analysis can be done by regular expression matching. I will assume the Lexical analysis is already done.
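For completeness, here is a rough sketch of what such a regular-expression lexer could look like. The names `TOKEN_PATTERN` and `tokenize` are made up for this example, and it assumes matrix names are ordinary identifiers and `@` is the only operator.

```python
import re

# A token is either an identifier (a matrix name) or the '@' operator,
# optionally preceded by whitespace.
TOKEN_PATTERN = re.compile(r'\s*([A-Za-z_]\w*|@)')


def tokenize(text):
    """Split an expression such as 'A @ B @ C' into a list of tokens."""
    tokens = []
    pos = 0
    text = text.rstrip()
    while pos < len(text):
        match = TOKEN_PATTERN.match(text, pos)
        if match is None:
            raise ValueError('unexpected input near position {}'.format(pos))
        tokens.append(match.group(1))
        pos = match.end()
    return tokens


print(tokenize('A @ B @ C @ D @ E @ F'))
# ['A', '@', 'B', '@', 'C', '@', 'D', '@', 'E', '@', 'F']
```

Anything that is neither a name nor `@` raises a `ValueError`, which is enough for the expressions used in this article.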
Syntax analysis
Given a list of tokens, the syntax analyzer uses a pre-defined syntax to parse tokens into an expression tree.
Implementation
A simple implementation to parse [“A”, “@”, “B”, “@”, “C”, “@”, “D”, “@”, “E”, “@”, “F”] into a tree is,
def build_expression_tree_simple(vars):
    '''
    build a single tree for a list of mats
    '''
    # A single matrix is a leaf.
    if len(vars) == 1:
        return VarNode(vars[0])

    # Otherwise: first matrix on the left, the rest (recursively) on the right.
    op_node = OpNode('@')
    op_node.left = VarNode(vars[0])
    op_node.right = build_expression_tree_simple(vars[1:])
    op_node.expression_string = \
        '({}*{})'.format(op_node.left.expression_string,
                         op_node.right.expression_string)
    return op_node
Using the implementation, the expression tree for [“A”, “@”, “B”, “@”, “C”, “@”, “D”, “@”, “E”, “@”, “F”] is,
 @_
/  \
A  @_
  /  \
  B  @_
    /  \
    C  @_
      /  \
      D  @
        / \
        E F
In the implementation, I ignored @ because it’s the only operation.
The idea is to recursively build the expression tree.
If the input has only one matrix, we just need to return a VarNode to represent the matrix.
If the input has more than one matrix, we construct an OpNode to represent a multiplication operation. The left child of the OpNode is the first input matrix. For the right child, we need to build another expression tree for the input without the first matrix. This is done by calling the function recursively.
Precedence and Associativity
You may have noticed that there are many ways to build the expression tree. For example, instead of the implementation above, we can build the left child from the first half of the matrices and the right child from the other half.
def build_expression_tree_mid(vars):
    '''
    build a single tree for a list of mats
    '''
    if len(vars) == 1:
        return VarNode(vars[0])

    op_node = OpNode('@')
    var_len = len(vars)
    # Split in the middle; each half is built with build_expression_tree_simple.
    op_node.left = build_expression_tree_simple(vars[: var_len // 2])
    op_node.right = build_expression_tree_simple(vars[var_len // 2 :])
    op_node.expression_string = \
        '({}*{})'.format(op_node.left.expression_string,
                         op_node.right.expression_string)
    return op_node
The tree will be,
  ___@_
 /     \
 @_    @_
/  \  /  \
A  @  D  @
  / \   / \
  B C   E F
A parser resolves this ambiguity by defining the precedence and the associativity of each operation. When building the tree, the parser grows it according to the rules given by the precedence and the associativity.
Luckily, for the optimal matrix multiplication problem, we don’t need to consider the Precedence and the Associativity.
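As a concrete illustration of associativity, Python's own parser treats `@` as left-associative, which is why an unparenthesized chain is evaluated from left to right. The sketch below uses the standard `ast` module to show this; `to_string` is a helper written only for this example.

```python
import ast


def to_string(node):
    """Render a parsed expression with explicit parentheses."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
        return '({}@{})'.format(to_string(node.left), to_string(node.right))
    raise ValueError('only names and @ are supported in this sketch')


parsed = ast.parse('A @ B @ C @ D', mode='eval')
print(to_string(parsed.body))   # (((A@B)@C)@D)
```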
3. Compute the Number of Multiplications
Given an expression tree for matrix multiplication, we can compute the total number of floating-point multiplications for it by doing a post-order traversal.
def count_num_muls(root):
    # A leaf is an input matrix: nothing to multiply.
    if root.node_type == 'VarNode':
        return 0
    assert root.node_type == 'OpNode'

    # Post-order: children first (this also sets their shapes).
    left_muls = count_num_muls(root.left)
    right_muls = count_num_muls(root.right)

    lr, lc = root.left.shape
    rr, rc = root.right.shape
    assert lc == rr, 'matrix dim mismatch'
    # The product of (lr x lc) and (rr x rc) has shape (lr x rc).
    root.shape = (lr, rc)
    cur_muls = lc * lr * rc
    total_muls = left_muls + right_muls + cur_muls
    return total_muls
First, we compute the number of muls needed to evaluate the left and the right children. Then we compute the number of muls needed to multiply the two results. In the end, we return the total number of muls.
One thing to keep in mind is that we need to update the shape of each OpNode, because the result of an OpNode is a matrix that can be the input to another OpNode.
4. Find the Optimal Expression Tree
Recall the goal of the article is to find the best order of matrix multiplication such that the total number of floating-point multiplications is minimized.
A naive method is to generate all possible expression trees for a list of tokens and compute the number of multiplications for each tree.
4.1 Generate all Expression Trees
Given a list of tokens [“A”, “B”, “C”, “D”, “E”, “F”], how can we generate all the expression trees?
This can be done by recursion.
def build_all_expression_tree(vars):
    '''
    build all trees for a list of mats
    '''
    if len(vars) == 1:
        return [VarNode(vars[0])]

    trees = []
    # Try every split point: vars[:i] goes left, vars[i:] goes right.
    for i in range(1, len(vars)):
        left_trees = build_all_expression_tree(vars[:i])
        right_trees = build_all_expression_tree(vars[i:])
        # Every (left, right) pair of sub-trees gives one valid tree.
        for l in left_trees:
            for r in right_trees:
                root = OpNode('@')
                root.left = l
                root.right = r
                root.expression_string = \
                    '({}@{})'.format(root.left.expression_string,
                                     root.right.expression_string)
                trees.append(root)
    return trees
The function outputs all possible expression trees for a list of tokens.
In the function, we iterate over all possible partitions for tokens. We call the function recursively to generate all possible sub-trees for the left-partitioned tokens and the right-partitioned tokens. Then, a combination of the left and the right sub-trees is a valid expression tree.
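It is worth knowing how many trees this recursion produces. The count for a chain of n matrices is the Catalan number C(n-1), a standard combinatorial fact that the article does not rely on elsewhere. The sketch below computes it both from the closed form and by mirroring the recursion; both function names are made up for this example.

```python
from math import comb


def num_expression_trees(n):
    """Number of ways to fully parenthesize a chain of n matrices (Catalan number C(n-1))."""
    return comb(2 * (n - 1), n - 1) // n


def count_by_recursion(n):
    """Same count, mirroring the split-and-combine recursion used above."""
    if n == 1:
        return 1
    return sum(count_by_recursion(i) * count_by_recursion(n - i)
               for i in range(1, n))


for n in (3, 6, 10):
    print(n, num_expression_trees(n), count_by_recursion(n))
# 3 2 2
# 6 42 42
# 10 4862 4862
```

For the 6 matrices used in this article that is only 42 trees, but the number grows exponentially with the chain length, so brute force only works for short chains.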
4.2 Find the Best Tree
To find the best tree, we just need to generate all trees and apply the function from section 3 to compute the number of muls for each. In the end, we select the tree with the fewest muls.
def test_all_tree():
    vars = gen_test_data()
    trees = build_all_expression_tree(vars)
    muls = [count_num_muls(t) for t in trees]

    for t, m in zip(trees, muls):
        print('for tree:', t.expression_string)
        print_tree(t)
        print('# muls:', m)

    # Pick the (tree, muls) pair with the smallest number of muls.
    min_tree, min_muls = min(zip(trees, muls), key=lambda x: x[1])
    print('optimal expression:', min_tree.expression_string)
    print('optimal # muls:', min_muls)
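For the tokens,
import numpy as np


# Assumed minimal definition: the original Variable class is not shown in this
# article; the code only relies on its `name` and `data` attributes.
class Variable:
    def __init__(self, name, data):
        self.name = name
        self.data = data


def gen_test_data():
    A = Variable('A', np.ones([20, 20]))
    B = Variable('B', np.ones([20, 35]))
    C = Variable('C', np.ones([35, 100]))
    D = Variable('D', np.ones([100, 360]))
    E = Variable('E', np.ones([360, 10]))
    F = Variable('F', np.ones([10, 10]))
    vars = [A, B, C, D, E, F]
    return vars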
part of the results are,
....
for tree: (((A@(B@(C@D)))@E)@F)
        _@
       /  \
  _____@  F
 /      \
 @_     E
/  \
A  @_
  /  \
  B  @
    / \
    C D
# muls: 1730000
for tree: (((A@((B@C)@D))@E)@F)
        _@
       /  \
  _____@  F
 /      \
 @___   E
/    \
A   _@
   /  \
   @  D
  / \
  B C
# muls: 1008000
for tree: ((((A@B)@(C@D))@E)@F)
        _@
       /  \
    ___@  F
   /    \
  _@_   E
 /   \
 @   @
/ \ / \
A B C D
# muls: 1600000
for tree: ((((A@(B@C))@D)@E)@F)
        _@
       /  \
      _@  F
     /  \
  ___@  E
 /    \
 @_   D
/  \
A  @
  / \
  B C
# muls: 904000
for tree: (((((A@B)@C)@D)@E)@F)
        _@
       /  \
      _@  F
     /  \
    _@  E
   /  \
  _@  D
 /  \
 @  C
/ \
A B
# muls: 878000
optimal expression: (A@((B@(C@(D@E)))@F))
optimal # muls: 408000
5. Dynamic Programming by Memoization
We can improve the above algorithm by memoizing computed results. For example, once the number of multiplications for A@(B@C) has been computed, we can save it in a dictionary. When we see the same expression again, we can just read the memoized result.
The implementation can be,
def count_num_muls(trees):
    dp_map = {}

    def count_num_muls_internal(root):
        s = root.expression_string
        if s in dp_map:
            count, shape = dp_map[s]
            # update the shape of mul
            root.shape = shape
            return count

        if root.node_type == 'VarNode':
            return 0
        assert root.node_type == 'OpNode'

        left_muls = count_num_muls_internal(root.left)
        right_muls = count_num_muls_internal(root.right)
        lr, lc = root.left.shape
        rr, rc = root.right.shape
        assert lc == rr, 'matrix dim mismatch'
        root.shape = (lr, rc)
        cur_muls = lc * lr * rc
        total_muls = left_muls + right_muls + cur_muls
        # need to track the shape of mul
        dp_map[s] = (total_muls, root.shape)
        return total_muls

    for t in trees:
        yield count_num_muls_internal(t)
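The memoized version above still enumerates every tree. For comparison, here is a sketch of the textbook bottom-up dynamic program for matrix chain ordering, which works on the shapes directly and never builds the trees. This is not the implementation used in this article: `optimal_order` and its `names`/`dims` arguments are assumptions made for this example, where matrix i has shape dims[i] x dims[i+1].

```python
def optimal_order(names, dims):
    """Return (minimal #-muls, parenthesized expression) for a matrix chain.

    names: list of n matrix names.
    dims:  list of n + 1 sizes; matrix i has shape dims[i] x dims[i + 1].
    """
    n = len(names)
    assert len(dims) == n + 1
    # cost[i][j]: minimal #-muls to compute the product of matrices i..j.
    cost = [[0] * n for _ in range(n)]
    # split[i][j]: the k at which the best order splits i..j into i..k and k+1..j.
    split = [[None] * n for _ in range(n)]

    for length in range(2, n + 1):            # length of the sub-chain
        for i in range(n - length + 1):
            j = i + length - 1
            best = None
            for k in range(i, j):
                c = (cost[i][k] + cost[k + 1][j]
                     + dims[i] * dims[k + 1] * dims[j + 1])
                if best is None or c < best:  # ties keep the smallest k
                    best = c
                    split[i][j] = k
            cost[i][j] = best

    def expr(i, j):
        if i == j:
            return names[i]
        k = split[i][j]
        return '({}@{})'.format(expr(i, k), expr(k + 1, j))

    return cost[0][n - 1], expr(0, n - 1)


print(optimal_order(list('ABC'), [40, 30, 20, 10]))
# (18000, '(A@(B@C))')
print(optimal_order(list('ABCDEF'), [20, 20, 35, 100, 360, 10, 10]))
# (408000, '(A@((B@(C@(D@E)))@F))')
```

This takes O(n^3) time and O(n^2) memory. For the test data there is a second order with the same cost of 408000, ((A@(B@(C@(D@E))))@F); the strict comparison keeps the first split found.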
6. Timing
Consider this NumPy expression A @ B @ C @ D @ E @ F. Its tokens and matrix sizes are,
def gen_test_data():
    A = Variable('A', np.ones([20, 20]))
    B = Variable('B', np.ones([20, 35]))
    C = Variable('C', np.ones([35, 100]))
    D = Variable('D', np.ones([100, 360]))
    E = Variable('E', np.ones([360, 10]))
    F = Variable('F', np.ones([10, 10]))
    vars = [A, B, C, D, E, F]
    return vars
The optimal matrix multiplication order (expression tree) is:
optimal expression: (A@((B@(C@(D@E)))@F))
 @_______
/        \
A   _____@
   /      \
   @_     F
  /  \
  B  @_
    /  \
    C  @
      / \
      D E
optimal # muls: 408000
Whereas for the default order,
for tree: (((((A@B)@C)@D)@E)@F)
        _@
       /  \
      _@  F
     /  \
    _@  E
   /  \
  _@  D
 /  \
 @  C
/ \
A B
# muls: 878000
I used the Python timeit module to measure the time to run each expression.
import timeit


def time_expr():
    print('timming....')
    setup_code = '''
from __main__ import gen_test_data
vars = gen_test_data()
data = [v.data for v in vars]
A, B, C, D, E, F = data
'''
    test_code1 = '''
val = A@B@C@D@E@F
'''
    test_code2 = '''
val = (A@((B@(C@(D@E)))@F))
'''
    # Note: timeit.repeat returns, for each repetition, the total time in
    # seconds for `number` executions of the statement, so the printed value
    # is in seconds for 10000 runs despite the 'ms' label.
    times = timeit.repeat(setup=setup_code, stmt=test_code1, number=10000)
    print('time for {} is {} ms'.format(test_code1, min(times)))

    times = timeit.repeat(setup=setup_code, stmt=test_code2, number=10000)
    print('time for {} is {} ms'.format(test_code2, min(times)))
The timing results are,
timming....
time for val = A@B@C@D@E@F is 1.5322212360042613 ms
time for val = (A@((B@(C@(D@E)))@F)) is 0.4744670180079993 ms
(A@((B@(C@(D@E)))@F)) is about 3 times faster than A@B@C@D@E@F.
(Although by the operation counts, the speed-up should be around 2 times… The run time for a modern computer is hard to reason about…)
7. My Real Goal and My Failure
The Python implementation generates the optimal expression at run time. To use this library, we may have to compute the optimal expressions at run time and replace the expressions in the source code later.
In practice, the sizes of the matrices are often known at compile time.
My goal was to use C++ template metaprogramming to compute the optimal matrix multiplication order at compile time.
C++ template metaprogramming is similar to what a compiler does: the computation is done before the program runs.
For example, the greatest common divisor can be implemented as,
// Loops inside a constexpr function require C++14 or later.
constexpr auto gcd(int a, int b){
    while (b != 0){
        auto t = b;
        b = a % b;
        a = t;
    }
    return a;
}

int main(){
    constexpr int i = gcd(11, 121);
}
Because the inputs to gcd are constants at compile time, the result is computed at compile time. The keyword constexpr does the trick. At run time, the code is equivalent to,
int main(){
    constexpr int i = 11;
}
For the matrix multiplication problem, the sizes of the matrices can be compile-time constants (e.g. fixed-size matrices in Eigen). So ideally we should be able to find the optimal matrix multiplication order at compile time.
My plan was to do some crazy template stuff. However, I couldn’t make it work. Creating a constexpr data structure seems tricky. Let me know if you have any ideas.
8. The End
Interestingly, I can’t find a C++ library that implements the above algorithm at compile time. Eigen seems to do some analysis of expressions, but it doesn’t appear to optimize the matrix multiplication order. Please let me know if there is such a library.
If you want to learn more about compilers, please take this online course: https://www.edx.org/course/compilers.
For more information about C++ template meta-programming, this video gave a good overview.