提交 76a6cd53 authored 作者: James Bergstra's avatar James Bergstra

new GEMM optimization algorithm

上级 3128a44c
...@@ -525,6 +525,9 @@ def _is_real_matrix(res): ...@@ -525,6 +525,9 @@ def _is_real_matrix(res):
and res.type.broadcastable[1] == False #cope with tuple vs. list and res.type.broadcastable[1] == False #cope with tuple vs. list
def _as_isolated_scalar_times_matrix(res): def _as_isolated_scalar_times_matrix(res):
"""Returns (scalar_var, matrix_var) on success else None
"""
# isolated means that there is only one client of the result 'res'
if res_is_a(res, T.mul, 1): if res_is_a(res, T.mul, 1):
if len(res.owner.inputs) == 2: if len(res.owner.inputs) == 2:
L, R = res.owner.inputs L, R = res.owner.inputs
...@@ -546,6 +549,11 @@ def _as_isolated_scalar_times_matrix(res): ...@@ -546,6 +549,11 @@ def _as_isolated_scalar_times_matrix(res):
else: else:
return None return None
if len(matrices) == 1: if len(matrices) == 1:
if len(scalars) == 0:
rval = (1.0, matrices[0])
elif len(scalars) == 1:
rval = (scalars[0], matrices[0])
else:
rval = (T.mul(*scalars), matrices[0]) rval = (T.mul(*scalars), matrices[0])
return rval return rval
...@@ -553,7 +561,9 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -553,7 +561,9 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip #print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
#EXPRESSION: (beta * L) + (alpha * M) #EXPRESSION: (beta * L) + (alpha * M)
if res_is_a(M, _dot22, 1): # we've already checked the client counts, now just make the type check.
####if res_is_a(M, _dot22, 1):
if M.owner and M.owner.op == _dot22:
Ml, Mr = M.owner.inputs Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)] rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
#print 'GEMM 0', rval, beta, L, alpha, M #print 'GEMM 0', rval, beta, L, alpha, M
...@@ -574,17 +584,14 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -574,17 +584,14 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
#EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v))))) #EXPRESSION: (beta * L) + (alpha * ((b*dot(x,y) + (a * dot(u, v)))))
#EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v)) #EXPRESSION: (beta * L) + (alpha*b*dot(x,y)) + (alpha * a * dot(u, v))
rval = [gemm_no_inplace(gemm_no_inplace(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)] rval = [gemm_no_inplace(gemm_no_inplace(L, alpha * b, x, y, beta), alpha * a, u, v, 1.0)]
print 'GEMM 1', rval
return rval return rval
if (G is L): if (G is L):
#EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v)) #EXPRESSION: (beta * L) + (alpha*b*L) + (alpha * a * dot(u, v))
rval = [gemm_no_inplace(L, alpha*a, u, v, alpha * b + beta)] rval = [gemm_no_inplace(L, alpha*a, u, v, alpha * b + beta)]
print 'GEMM 2', rval
return rval return rval
if (1.0 != alpha): if (1.0 != alpha):
#at the very least, move the alpha inside the gemm_no_inplace #at the very least, move the alpha inside the gemm_no_inplace
rval = [beta * L + gemm_no_inplace(G, alpha * a, u, v, alpha * b)] rval = [beta * L + gemm_no_inplace(G, alpha * a, u, v, alpha * b)]
print 'GEMM 3', rval
return rval return rval
if recurse_flip: if recurse_flip:
...@@ -592,43 +599,174 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True): ...@@ -592,43 +599,174 @@ def _beta_L_plus_alpha_M(beta, L, alpha, M, recurse_flip = True):
else: else:
return False return False
def _gemm_from_node(node):
"""
:todo: In many expressions, there are many ways to turn it into a gemm. For example
dot(a,b) + c + d. This function should return all of them, so that if one version of gemm
causes a cycle in the graph, then another application of gemm can be tried.
""" def _gemm_canonicalize(r, scale, rval, maxclients):
if node.op == T.sub: # Tries to interpret node as a sum of scalars * matrices
L, R = node.inputs def scaled(thing):
if not _is_real_matrix(L): if scale == 1:
return False return thing
if not _is_real_matrix(R): if scale == -1:
return False return -thing
else:
return scale*thing
if (tuple(r.type.broadcastable) != (False, False) or
r.type.dtype not in ('float32', 'float64', 'complex64', 'complex128')):
rval.append(scaled(r))
return rval
if maxclients and len(getattr(r,'clients',[])) > maxclients:
rval.append((scale, r))
return rval
if r.owner and r.owner.op == T.sub:
_gemm_canonicalize(r.owner.inputs[0], scale, rval, 1)
_gemm_canonicalize(r.owner.inputs[1], -scale, rval, 1)
elif r.owner and r.owner.op == T.add:
for i in r.owner.inputs:
_gemm_canonicalize(i, scale, rval, 1)
elif r.owner and r.owner.op == T.neg:
_gemm_canonicalize(r.owner.inputs[0], -scale, rval, 1)
elif r.owner and r.owner.op == T.mul:
scalars = []
matrices = []
for i in r.owner.inputs:
if numpy.all(i.type.broadcastable):
while i.owner and isinstance(i.owner.op, T.DimShuffle):
i = i.owner.inputs[0]
if i.type.broadcastable:
scalars.append(i.dimshuffle())
else:
scalars.append(i)
elif _is_real_matrix(i):
matrices.append(i)
else:
# just put the original arguments as in the base case
rval.append((scale,r))
return rval
if len(matrices)==1:
m = matrices[0]
if len(scalars) == 0:
_gemm_canonicalize(m, scale, rval, 1)
elif len(scalars) == 1:
_gemm_canonicalize(m, scaled(scalars[0]), rval, 1)
else:
_gemm_canonicalize(m, T.mul(scaled(scalars[0]), *scalars[1:]), rval, 1)
else: #there are many matrices... lets not open this up
rval.append((scale,r))
else:
rval.append((scale,r))
return rval
tmp = _as_isolated_scalar_times_matrix(L) def _factor_canonicalized(lst):
# remove duplicates from canonicalized list
# we only delete out of the right end of the list,
# once i has touched a list element, it is permantent
lst = list(lst)
#print 'FACTOR', lst
#for (a,b) in lst:
#theano.printing.debugprint(a)
#theano.printing.debugprint(b)
i = 0
while i < len(lst)-1:
try: try:
sL, mL = tmp s_i,M_i = lst[i]
except: except:
sL, mL = 1.0, L i += 1
continue
j = i+1
while j < len(lst):
try:
s_j,M_j = lst[j]
except:
j += 1
continue
if M_i is M_j:
s_i = s_i + s_j
lst[i] = (s_i, M_i)
del lst[j]
else:
j += 1
i+=1
return lst
def _gemm_from_factored_list(lst):
"""Returns None, or a list to replace node.outputs
"""
# Try every pair in the sM_list, trying to turn it into a gemm operation
for i in xrange(len(lst) - 1):
try:
s_i,M_i = lst[i]
except:
continue
for j in xrange(i+1, len(lst)):
tmp = _as_isolated_scalar_times_matrix(R)
try: try:
sR, mR = tmp s_j, M_j = lst[j]
except: except:
sR, mR = 1.0, R continue
rval = _beta_L_plus_alpha_M(sL, mL, -sR, mR)
#print 'TRYING', (s_i, M_i, s_j, M_j)
gemm_of_sM_list = _beta_L_plus_alpha_M(s_i, M_i, s_j, M_j)
if gemm_of_sM_list:
#print 'GOT IT', gemm_of_sM_list
def item_to_var(t):
try: s,M = t
except: return t
if s == 1: return M
if s == -1: return -M
return s*M
assert len(gemm_of_sM_list) == 1
add_inputs = [item_to_var(input)
for k, input in enumerate(lst) if k not in (i,j)]
add_inputs.extend(gemm_of_sM_list)
if len(add_inputs) > 1:
return [T.add(*add_inputs)]
else:
return add_inputs
def _gemm_from_node2(node):
"""
:todo: In many expressions, there are many ways to turn it into a gemm. For example
dot(a,b) + c + d. This function should return all of them, so that if one version of gemm
causes a cycle in the graph, then another application of gemm can be tried.
"""
lst = []
_gemm_canonicalize(node.outputs[0], 1.0, lst, 0)
if len(lst) > 1:
lst = _factor_canonicalized(lst)
rval = _gemm_from_factored_list(lst)
return rval return rval
if node.op == T.add:
# arguments of the form scalar * matrix
sM_list = []
# arguments that can be interpreted as scalar * matrix
sM_orig = []
# arguments not of the form scalar * matrix (i.e., vectors, scalars) def inputs_as_scalar_times_matrix(node):
other_inputs = []
# try to interpret an expression as a sum of scalar * matrix terms plus an 'other' term.
# This function *could* recurse and flatten sub and add hierarchies, but it doesn't.
# Reason being - if we didn't need intermediate results, the canonizer should already done
# that.
# returns three lists: sM_list, sM_orig, other
# - sM_list is a list of pairs: the interpretation of some terms as scalar,matrix products
# - sM_orig is a list of variables: the originals before interpretation into sM_list
# - other is a list of terms that are not float matrices
op = None
sM_list = []
sM_orig = []
other = []
if node.op == T.add or node.op == T.sub:
op = node.op
for input in node.inputs: for input in node.inputs:
tmp = _as_isolated_scalar_times_matrix(input) tmp = _as_isolated_scalar_times_matrix(input)
if tmp: if tmp:
...@@ -638,10 +776,16 @@ def _gemm_from_node(node): ...@@ -638,10 +776,16 @@ def _gemm_from_node(node):
sM_list.append((1.0, input)) sM_list.append((1.0, input))
sM_orig.append(input) sM_orig.append(input)
else: else:
other_inputs.append(input) other.append(input)
assert len(sM_list) == len(sM_orig) assert len(sM_list) == len(sM_orig)
assert len(sM_list) + len(other_inputs) == len(node.inputs) assert len(sM_list) + len(other) == len(node.inputs)
return op, sM_list, sM_orig, other
def _gemm_from_sM_list(node, sM_list, sM_orig, other_inputs):
"""Returns None, or a list to replace node.outputs
"""
if len(sM_list) == 2: if len(sM_list) == 2:
(sL, mL), (sR, mR) = sM_list (sL, mL), (sR, mR) = sM_list
gemm_of_sM_list = _beta_L_plus_alpha_M(sL, mL, sR, mR) gemm_of_sM_list = _beta_L_plus_alpha_M(sL, mL, sR, mR)
...@@ -666,21 +810,56 @@ def _gemm_from_node(node): ...@@ -666,21 +810,56 @@ def _gemm_from_node(node):
new_add_inputs = (inputs_without_ij + gemm_of_sM_list + other_inputs) new_add_inputs = (inputs_without_ij + gemm_of_sM_list + other_inputs)
if False: #SUPER DEBUG MODE :(
if len(new_add_inputs) + 1 != len(node.inputs):
print 'inputs', node.inputs
print 'sM, other', sM_list, other_inputs
print 'i,j', i, j
print 'gemm', gemm_of_sM_list
print 'without ij', inputs_without_ij
print 'new inputs', new_add_inputs
sys.exit(1)
# this should be True because we've combined a pair of arguments # this should be True because we've combined a pair of arguments
# into a single GEMM # into a single GEMM
assert len(new_add_inputs) + 1 == len(node.inputs) assert len(new_add_inputs) + 1 == len(node.inputs)
return [T.add(*new_add_inputs)] return [T.add(*new_add_inputs)]
return False
def _gemm_from_node(node):
"""
:todo: In many expressions, there are many ways to turn it into a gemm. For example
dot(a,b) + c + d. This function should return all of them, so that if one version of gemm
causes a cycle in the graph, then another application of gemm can be tried.
"""
op, sM_list, sM_orig, other_inputs = inputs_as_scalar_times_matrix(node)
if op == T.sub and len(sM_list)==2:
(sL, mL), (sR,mR) = sM_list
rval = _gemm_from_sM_list([(sL, mL), (-sR,mR)], None, None)
if rval:
return rval
#theano.printing.debugprint(node.outputs[0], depth=6)
if len(sM_orig[1].clients)==1:
# Canonicalize this subgraph
# There is a form of Gemm that escapes the approach above
# g*W - (a * (e*dot(b,c) + d * W + X))
#
# -> gemm(W, -a*e, b, c, g-a*d) - a*X
#
# In this case g=sL W=mL, and a=sR. We must see if mR is a add() or a sub, in which
# one of the arguments is a scaled version of W a.k.a mL
Rop, RsM_list, RsM_orig, Rother_inputs = inputs_as_scalar_times_matrix(mR.owner)
RsM_list_that_is_mL = [s for (s,m) in RsM_list if m is mL]
if RsM_list_that_is_mL and Rop == T.add:
pass
#g= sL - T.mul(sR,*RsM_list_that_is_mL)
#rval = _gemm_from_sM_list(
#[(g,mL)] + []]
#]
#)
#if Rop == T.add:
#rval = _beta_L_plus_alpha_M(
#L=mL,
#alpha=sR,
#R=T.)
return rval
if op == T.add:
return _gemm_from_sM_list(sM_list, sM_orig, other_inputs)
class GemmOptimizer(Optimizer): class GemmOptimizer(Optimizer):
"""Graph optimizer for inserting Gemm operations""" """Graph optimizer for inserting Gemm operations"""
...@@ -698,7 +877,8 @@ class GemmOptimizer(Optimizer): ...@@ -698,7 +877,8 @@ class GemmOptimizer(Optimizer):
did_something = False did_something = False
nodelist.reverse() nodelist.reverse()
for node in nodelist: for node in nodelist:
new_outputs = _gemm_from_node(node) #new_outputs = _gemm_from_node(node)
new_outputs = _gemm_from_node2(node)
if new_outputs: if new_outputs:
assert len(new_outputs) == len(node.outputs) assert len(new_outputs) == len(node.outputs)
try: try:
......
from nose.plugins.skip import SkipTest
import traceback import traceback
import theano.tensor as T import theano.tensor as T
from theano.gof import Env from theano.gof import Env
from theano.printing import pp from theano.printing import pp
import numpy, theano import numpy, theano
from theano.tensor.blas import * from theano.tensor.blas import *
from theano.tensor.blas import _dot22, _dot22scalar, res_is_a, _as_scalar, _is_real_matrix from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar, _is_real_matrix,
_gemm_canonicalize, _factor_canonicalized)
from unittest import TestCase from unittest import TestCase
from theano.tests import unittest_tools from theano.tests import unittest_tools
from copy import copy from copy import copy
...@@ -267,16 +269,24 @@ class Failure(Exception): ...@@ -267,16 +269,24 @@ class Failure(Exception):
class Warning(Exception): class Warning(Exception):
pass pass
def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]): def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0):
try: try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o, mode='FAST_RUN') f = inplace_func([Param(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
at_least_one_gemm = False
for node in f.maker.env.nodes: for node in f.maker.env.nodes:
if node.op == T.dot: raise Warning('dot not changed to gemm_inplace in graph') if node.op == T.dot: raise Warning('dot not changed to gemm_inplace in graph')
if node.op == _dot22: raise Warning('_dot22 not changed to gemm_inplace in graph') if node.op == _dot22: raise Warning('_dot22 not changed to gemm_inplace in graph')
if node.op == gemm_inplace: at_least_one_gemm = True
assert at_least_one_gemm
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None)) g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None))
for node in g.maker.env.nodes: for node in g.maker.env.nodes:
if node.op == gemm_inplace: raise Exception('gemm_inplace in original graph') if node.op == gemm_inplace: raise Exception('gemm_inplace in original graph')
graphlen = len(f.maker.env.toposort())
if max_graphlen and (graphlen <= max_graphlen):
theano.printing.debugprint(f)
assert False, 'graphlen=%i>%i'%(graphlen, max_graphlen)
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes]) r0 = f(*[rng.randn(*sh) for sh in ishapes])
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
...@@ -353,12 +363,76 @@ def test_gemm_opt_double_gemm(): ...@@ -353,12 +363,76 @@ def test_gemm_opt_double_gemm():
print 'GRAPH', node print 'GRAPH', node
raise raise
def wishlist_gemm_opt():
def test_gemm_canonicalize():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
u = T.row('u')
can = []
_gemm_canonicalize(X + Y + Z, 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, Z)]
can = []
_gemm_canonicalize(X + Y + u, 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), u]
can = []
_gemm_canonicalize(a*X + Y - b*Z*c, 1.0, can, 0)
assert can[0] == (a, X)
assert can[1] == (1.0, Y)
assert can[2][0].owner.op == T.mul
assert can[2][0].owner.inputs[0].owner.op == T.neg
assert can[2][0].owner.inputs[0].owner.inputs[0] == c
assert can[2][0].owner.inputs[1] == b
can = []
_gemm_canonicalize((-d) * X - (a*X + Y - b*Z*c), 1.0, can, 0)
print can
assert can[0][0].owner.op == T.neg
assert can[0][0].owner.inputs[0] == d
assert can[0][1] == X
assert can[1][0].owner.op == T.neg
assert can[1][0].owner.inputs[0] == a
assert can[2] == (-1.0, Y)
assert can[3][0].owner.op == T.mul
assert can[3][0].owner.inputs == [c,b]
def test_gemm_factor():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
u = T.row('u')
assert [(1.0, X), (1.0, Y), u] == _factor_canonicalized([(1.0, X), (1.0, Y), u])
assert [(2.0, X), u] == _factor_canonicalized([(1.0, X),(1.0, X), u])
def test_gemm_nested():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
u = T.row('u')
just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
max_graphlen=1)
print "---------------------"
just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z + c*Z)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
max_graphlen=1)
print "---------------------"
just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z + c*U)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
max_graphlen=3)
def test_gemm_opt_wishlist():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
#with >2 additions of the same T.dot(X,Y term #with >2 additions of the same T.dot(X,Y term
just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
def test_gemm_with_vector(): def test_gemm_with_vector():
"""Many subgraphs whose dots can be eliminated. """Many subgraphs whose dots can be eliminated.
...@@ -423,9 +497,9 @@ def test_inplace1(): ...@@ -423,9 +497,9 @@ def test_inplace1():
# with > 2 terms in the overall addition # with > 2 terms in the overall addition
f = inplace_func([X,Y,Z,a,b], f = inplace_func([X,Y,Z,a,b],
[Z + Z + T.dot(X,Y)], mode='FAST_RUN') [Z + Z + T.dot(X,Y)], mode='FAST_RUN')
# gemm_inplace should operate in-place on (Z+Z) theano.printing.debugprint(f)
if (not gemm_inplace in [n.op for n in f.maker.env.nodes]): # it doesn't work inplace because we didn't mark Z as mutable input
raise Failure('no gemm_inplace in graph') assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace]
def test_dot22(): def test_dot22():
if config.mode == 'FAST_COMPILE': if config.mode == 'FAST_COMPILE':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论