提交 92956aec authored 作者: Samira Shabanian's avatar Samira Shabanian

changed f_test

上级 bce7c0a2
......@@ -3593,7 +3593,7 @@ def local_sumsqr2dot(node):
if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Sqr)):
in_sqr = in1.owner.inputs[0]
if (in_sqr.owner and isinstance(in_sqr.owner.op, T.Elemwise) and
isinstance(in_sqr.owner.op.scalar_op, theano.scalar.basic.Mul)):
isinstance(in_sqr.owner.op.scalar_op, theano.scalar.basic.Mul) and len(in_sqr.owner.inputs) == 2):
in_mul1, in_mul2 = in_sqr.owner.inputs
if (isinstance(in_mul1.owner.op, T.elemwise.DimShuffle) and in_mul1.owner.op.new_order == ('x', 0, 1) and
......
......@@ -6177,7 +6177,7 @@ def test_local_sumsqr2dot():
G = matrix('G')
W = matrix('W')
y = T.sqr( W.dimshuffle('x',0,1) * G.dimshuffle(0,'x',1) ).sum(axis=(1,2))
y = T.sqr(W.dimshuffle('x', 0, 1) * G.dimshuffle(0, 'x', 1)).sum(axis=(1, 2))
MODE = theano.compile.get_default_mode().including('local_sumsqr2dot')
f = function([W, G], y, mode=MODE)
......@@ -6186,9 +6186,9 @@ def test_local_sumsqr2dot():
g_val = numpy.random.rand(5, 3).astype(config.floatX)
f_val = f(w_val, g_val)
f_test = function([W,G], T.dot(T.sqr(G), T.sqr(W).sum(axis=0)), mode=MODE)
f_test = numpy.dot(numpy.square(g_val), numpy.square(w_val).sum(axis=0))
assert numpy.allclose(f_val, f_test(w_val, g_val))
assert numpy.allclose(f_val, f_test)
assert any(isinstance(n.op, (tensor.basic.Dot, tensor.blas.Dot22,
tensor.blas.Gemv, tensor.blas_c.CGemv))
for n in f.maker.fgraph.toposort())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论