提交 f58e14cc authored 作者: Frederic's avatar Frederic

Add a test that show when the gemm optimizer add a gemm, but didn't remove a dot22.

上级 d604b4be
......@@ -45,6 +45,10 @@ def test_dot_eq():
assert T.Dot() == T.Dot()
def sharedX(x, name):
return theano.shared(numpy.asarray(x, config.floatX), name=name)
class t_gemm(TestCase):
"""This test suite is supposed to establish that gemm works as it
is supposed to.
......@@ -480,7 +484,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
raise Failure('_dot22 not changed to gemm_inplace in graph')
if node.op == gemm_inplace:
nb_gemm += 1
assert nb_gemm == expected_nb_gemm
assert nb_gemm == expected_nb_gemm, (nb_gemm, expected_nb_gemm)
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
allow_input_downcast=True, on_unused_input='ignore')
for node in g.maker.env.nodes:
......@@ -712,8 +716,8 @@ def test_gemm_opt_wishlist():
X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
#with >2 additions of the same T.dot(X,Y term
just_gemm([X, Y, Z, a, b], [(b * b) * Z * a + (a * a) * T.dot(X, Y) +
b * T.dot(X, Y)])
just_gemm([X, Y, Z, a, b],
[(b * b) * Z * a + (a * a) * T.dot(X, Y) + b * T.dot(X, Y)])
just_gemm([X, Y, Z, a, b], [Z + T.dot(X, Y) + T.dot(X, Y)])
......@@ -763,6 +767,44 @@ def test_gemm_opt_vector_stuff():
raise Failure('gemm_inplace in graph')
def test_gemm_unrolled():
batch_size = 100
rep_size = 40
rng = numpy.random.RandomState([1, 2, 3])
for num_rounds in range(1, 10):
W = sharedX(rng.randn(rep_size, rep_size), name='W')
V = sharedX(numpy.zeros((batch_size, rep_size)), name='V')
H = sharedX(numpy.zeros((batch_size, rep_size)), name='H')
G = sharedX(numpy.zeros((batch_size, rep_size)), name='G')
init_V = sharedX(rng.uniform(0, 1, (batch_size, rep_size)), name='init_V')
init_H = sharedX(rng.uniform(0, 1, (batch_size, rep_size)), name='init_H')
cur_V = V
cur_H = H
def update_V(cur_H):
return T.nnet.sigmoid(T.dot(cur_H, W.T))
def update_H(cur_V):
return T.nnet.sigmoid(T.dot(cur_V, W) + T.dot(G, W.T))
for i in xrange(num_rounds):
cur_V = update_V(cur_H)
cur_H = update_H(cur_V)
unrolled_theano = theano.function([], updates={V: cur_V, H: cur_H},
name='unrolled_theano')
nb_dot = sum([1 for node in unrolled_theano.maker.env.toposort()
if isinstance(node.op, (theano.tensor.Dot,
theano.tensor.blas.Dot22,
theano.tensor.blas.Gemm))])
assert nb_dot == num_rounds * 2 + 1, nb_dot
unrolled_theano()
def test_inplace0():
#should fail to insert gemm_inplace because gemm_inplace would
#create cycles
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论