提交 b7c6c2ae authored 作者: Amjad Almahairi's avatar Amjad Almahairi

switch test left is right with var cond

上级 d3bee2da
...@@ -4600,10 +4600,13 @@ class test_local_remove_switch_const_cond(unittest.TestCase): ...@@ -4600,10 +4600,13 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
for dtype1 in ['int32', 'int64']: for dtype1 in ['int32', 'int64']:
x = theano.tensor.matrix('x', dtype=dtype1) x = theano.tensor.matrix('x', dtype=dtype1)
varc = theano.tensor.matrix('varc', dtype=dtype1)
z1 = theano.tensor.switch(1, x, x) z1 = theano.tensor.switch(1, x, x)
z0 = theano.tensor.switch(0, x, x) z0 = theano.tensor.switch(0, x, x)
z2 = theano.tensor.switch(varc, x, x)
f1 = theano.function([x], z1, mode=self.mode) f1 = theano.function([x], z1, mode=self.mode)
f0 = theano.function([x], z0, mode=self.mode) f0 = theano.function([x], z0, mode=self.mode)
f2 = theano.function([x,varc], z2, mode=self.mode)
topo = f1.maker.fgraph.toposort() topo = f1.maker.fgraph.toposort()
assert len(topo) == 1 assert len(topo) == 1
...@@ -4613,9 +4616,15 @@ class test_local_remove_switch_const_cond(unittest.TestCase): ...@@ -4613,9 +4616,15 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
assert len(topo) == 1 assert len(topo) == 1
assert topo[0].op == deep_copy_op assert topo[0].op == deep_copy_op
topo = f2.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
vx = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) vx = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vc = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
assert numpy.all(f1(vx) == vx) assert numpy.all(f1(vx) == vx)
assert numpy.all(f0(vx) == vx) assert numpy.all(f0(vx) == vx)
assert numpy.all(f2(vx,vc) == vx)
def test_shape_le_0(self): def test_shape_le_0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论