提交 76a6cd53 authored 作者: James Bergstra's avatar James Bergstra

new GEMM optimization algorithm

上级 3128a44c
差异被折叠。
from nose.plugins.skip import SkipTest
import traceback
import theano.tensor as T
from theano.gof import Env
from theano.printing import pp
import numpy, theano
from theano.tensor.blas import *
from theano.tensor.blas import _dot22, _dot22scalar, res_is_a, _as_scalar, _is_real_matrix
from theano.tensor.blas import (_dot22, _dot22scalar, res_is_a, _as_scalar, _is_real_matrix,
_gemm_canonicalize, _factor_canonicalized)
from unittest import TestCase
from theano.tests import unittest_tools
from copy import copy
......@@ -267,16 +269,24 @@ class Failure(Exception):
class Warning(Exception):
pass
def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()]):
def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0):
try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o, mode='FAST_RUN')
at_least_one_gemm = False
for node in f.maker.env.nodes:
if node.op == T.dot: raise Warning('dot not changed to gemm_inplace in graph')
if node.op == _dot22: raise Warning('_dot22 not changed to gemm_inplace in graph')
if node.op == gemm_inplace: at_least_one_gemm = True
assert at_least_one_gemm
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None))
for node in g.maker.env.nodes:
if node.op == gemm_inplace: raise Exception('gemm_inplace in original graph')
graphlen = len(f.maker.env.toposort())
if max_graphlen and (graphlen <= max_graphlen):
theano.printing.debugprint(f)
assert False, 'graphlen=%i>%i'%(graphlen, max_graphlen)
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes])
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
......@@ -353,12 +363,76 @@ def test_gemm_opt_double_gemm():
print 'GRAPH', node
raise
def wishlist_gemm_opt():
def test_gemm_canonicalize():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
u = T.row('u')
can = []
_gemm_canonicalize(X + Y + Z, 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, Z)]
can = []
_gemm_canonicalize(X + Y + u, 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), u]
can = []
_gemm_canonicalize(a*X + Y - b*Z*c, 1.0, can, 0)
assert can[0] == (a, X)
assert can[1] == (1.0, Y)
assert can[2][0].owner.op == T.mul
assert can[2][0].owner.inputs[0].owner.op == T.neg
assert can[2][0].owner.inputs[0].owner.inputs[0] == c
assert can[2][0].owner.inputs[1] == b
can = []
_gemm_canonicalize((-d) * X - (a*X + Y - b*Z*c), 1.0, can, 0)
print can
assert can[0][0].owner.op == T.neg
assert can[0][0].owner.inputs[0] == d
assert can[0][1] == X
assert can[1][0].owner.op == T.neg
assert can[1][0].owner.inputs[0] == a
assert can[2] == (-1.0, Y)
assert can[3][0].owner.op == T.mul
assert can[3][0].owner.inputs == [c,b]
def test_gemm_factor():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
u = T.row('u')
assert [(1.0, X), (1.0, Y), u] == _factor_canonicalized([(1.0, X), (1.0, Y), u])
assert [(2.0, X), u] == _factor_canonicalized([(1.0, X),(1.0, X), u])
def test_gemm_nested():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
u = T.row('u')
just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
max_graphlen=1)
print "---------------------"
just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z + c*Z)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
max_graphlen=1)
print "---------------------"
just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z + c*U)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
max_graphlen=3)
def test_gemm_opt_wishlist():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
#with >2 additions of the same T.dot(X,Y term
just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
just_gemm([X,Y,Z,a,b], [Z + T.dot(X,Y) + T.dot(X,Y)])
def test_gemm_with_vector():
"""Many subgraphs whose dots can be eliminated.
......@@ -423,9 +497,9 @@ def test_inplace1():
# with > 2 terms in the overall addition
f = inplace_func([X,Y,Z,a,b],
[Z + Z + T.dot(X,Y)], mode='FAST_RUN')
# gemm_inplace should operate in-place on (Z+Z)
if (not gemm_inplace in [n.op for n in f.maker.env.nodes]):
raise Failure('no gemm_inplace in graph')
theano.printing.debugprint(f)
# it doesn't work inplace because we didn't mark Z as mutable input
assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace]
def test_dot22():
if config.mode == 'FAST_COMPILE':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论