提交 e24eb744 authored 作者: James Bergstra's avatar James Bergstra

fixed gemm_inplace tests

上级 616b7f60
import traceback
import theano.tensor as T
from ...gof import Env
from ...printing import pp
import numpy
from theano.tensor.blas import *
from theano.tensor.blas import _dot22, res_is_a
......@@ -391,19 +392,27 @@ def test_gemm_opt_vector_stuff():
def test_inplace0():
#should fail to insert gemm because gemm would create cycles
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R, S, c = T.dmatrix('R'), T.dmatrix('S'), T.dscalar('c')
f = inplace_func([X,Y,Z,a,b, R, S, c],
[Z * (Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T)], mode='FAST_RUN')
[Z * (Z + b * T.dot(R,S).T)], mode='FAST_RUN')
if (gemm in [n.op for n in f.maker.env.nodes]):
print pp(f.maker.env.outputs[0])
raise Failure('gemm in graph')
f = inplace_func([X,Y,Z,a,b, R, S, c],
[Z * (c*Z + a * T.dot(X,Y) + b * T.dot(R,S).T)], mode='FAST_RUN')
# gemm should be insertedd here, to work in-place on Z*c
if (not gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('no gemm in graph')
def test_inplace1():
X,Y,Z,a,b = XYZab()
# with > 2 terms in the overall addition
f = inplace_func([X,Y,Z,a,b],
[Z + Z + T.dot(X,Y)], mode='FAST_RUN')
if (gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('gemm in graph')
# gemm should operate in-place on (Z+Z)
if (not gemm in [n.op for n in f.maker.env.nodes]):
raise Failure('no gemm in graph')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论