提交 f58e14cc authored 作者: Frederic's avatar Frederic

Add a test that show when the gemm optimizer add a gemm, but didn't remove a dot22.

上级 d604b4be
...@@ -45,6 +45,10 @@ def test_dot_eq(): ...@@ -45,6 +45,10 @@ def test_dot_eq():
assert T.Dot() == T.Dot() assert T.Dot() == T.Dot()
def sharedX(x, name):
return theano.shared(numpy.asarray(x, config.floatX), name=name)
class t_gemm(TestCase): class t_gemm(TestCase):
"""This test suite is supposed to establish that gemm works as it """This test suite is supposed to establish that gemm works as it
is supposed to. is supposed to.
...@@ -480,7 +484,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], ...@@ -480,7 +484,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
raise Failure('_dot22 not changed to gemm_inplace in graph') raise Failure('_dot22 not changed to gemm_inplace in graph')
if node.op == gemm_inplace: if node.op == gemm_inplace:
nb_gemm += 1 nb_gemm += 1
assert nb_gemm == expected_nb_gemm assert nb_gemm == expected_nb_gemm, (nb_gemm, expected_nb_gemm)
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None), g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
allow_input_downcast=True, on_unused_input='ignore') allow_input_downcast=True, on_unused_input='ignore')
for node in g.maker.env.nodes: for node in g.maker.env.nodes:
...@@ -712,8 +716,8 @@ def test_gemm_opt_wishlist(): ...@@ -712,8 +716,8 @@ def test_gemm_opt_wishlist():
X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar() X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
#with >2 additions of the same T.dot(X,Y term #with >2 additions of the same T.dot(X,Y term
just_gemm([X, Y, Z, a, b], [(b * b) * Z * a + (a * a) * T.dot(X, Y) + just_gemm([X, Y, Z, a, b],
b * T.dot(X, Y)]) [(b * b) * Z * a + (a * a) * T.dot(X, Y) + b * T.dot(X, Y)])
just_gemm([X, Y, Z, a, b], [Z + T.dot(X, Y) + T.dot(X, Y)]) just_gemm([X, Y, Z, a, b], [Z + T.dot(X, Y) + T.dot(X, Y)])
...@@ -763,6 +767,44 @@ def test_gemm_opt_vector_stuff(): ...@@ -763,6 +767,44 @@ def test_gemm_opt_vector_stuff():
raise Failure('gemm_inplace in graph') raise Failure('gemm_inplace in graph')
def test_gemm_unrolled():
batch_size = 100
rep_size = 40
rng = numpy.random.RandomState([1, 2, 3])
for num_rounds in range(1, 10):
W = sharedX(rng.randn(rep_size, rep_size), name='W')
V = sharedX(numpy.zeros((batch_size, rep_size)), name='V')
H = sharedX(numpy.zeros((batch_size, rep_size)), name='H')
G = sharedX(numpy.zeros((batch_size, rep_size)), name='G')
init_V = sharedX(rng.uniform(0, 1, (batch_size, rep_size)), name='init_V')
init_H = sharedX(rng.uniform(0, 1, (batch_size, rep_size)), name='init_H')
cur_V = V
cur_H = H
def update_V(cur_H):
return T.nnet.sigmoid(T.dot(cur_H, W.T))
def update_H(cur_V):
return T.nnet.sigmoid(T.dot(cur_V, W) + T.dot(G, W.T))
for i in xrange(num_rounds):
cur_V = update_V(cur_H)
cur_H = update_H(cur_V)
unrolled_theano = theano.function([], updates={V: cur_V, H: cur_H},
name='unrolled_theano')
nb_dot = sum([1 for node in unrolled_theano.maker.env.toposort()
if isinstance(node.op, (theano.tensor.Dot,
theano.tensor.blas.Dot22,
theano.tensor.blas.Gemm))])
assert nb_dot == num_rounds * 2 + 1, nb_dot
unrolled_theano()
def test_inplace0(): def test_inplace0():
#should fail to insert gemm_inplace because gemm_inplace would #should fail to insert gemm_inplace because gemm_inplace would
#create cycles #create cycles
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论