提交 c02d3283 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed the test for the optimization that removes switches with constant

condition.
上级 41946da9
......@@ -2284,7 +2284,9 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
z = theano.tensor.switch(0, x, y)
f = theano.function([x,y], z, mode=self.mode)
assert len([node.op for node in f.maker.env.toposort() if
isinstance(node.op,theano.tensor.Elemwise) ]) == 0
( isinstance(node.op,theano.tensor.Elemwise)
and isinstance(node.op.scalar_op,
theano.scalar.basic.Switch))]) == 0
vx = numpy.array([[1,2,3],[ 4, 5, 6]], dtype=dtype1)
vy = numpy.array([[7,8,9],[10,11,12]], dtype=dtype2)
assert numpy.all(f(vx,vy) == vy)
......@@ -2298,7 +2300,9 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
z = theano.tensor.switch(1, x, y)
f = theano.function([x,y], z, mode=self.mode)
assert len([node.op for node in f.maker.env.toposort() if
isinstance(node.op,theano.tensor.Elemwise) ]) == 0
( isinstance(node.op,theano.tensor.Elemwise)
and isinstance(node.op.scalar_op,
theano.scalar.basic.Switch))]) == 0
vx = numpy.array([[1,2,3],[ 4, 5, 6]], dtype=dtype1)
vy = numpy.array([[7,8,9],[10,11,12]], dtype=dtype2)
assert numpy.all(f(vx,vy) == vx)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论