提交 b7c6c2ae authored 作者: Amjad Almahairi's avatar Amjad Almahairi

switch test left is right with var cond

上级 d3bee2da
......@@ -4600,10 +4600,13 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
for dtype1 in ['int32', 'int64']:
x = theano.tensor.matrix('x', dtype=dtype1)
varc = theano.tensor.matrix('varc', dtype=dtype1)
z1 = theano.tensor.switch(1, x, x)
z0 = theano.tensor.switch(0, x, x)
z2 = theano.tensor.switch(varc, x, x)
f1 = theano.function([x], z1, mode=self.mode)
f0 = theano.function([x], z0, mode=self.mode)
f2 = theano.function([x,varc], z2, mode=self.mode)
topo = f1.maker.fgraph.toposort()
assert len(topo) == 1
......@@ -4613,9 +4616,15 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
assert len(topo) == 1
assert topo[0].op == deep_copy_op
topo = f2.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
vx = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vc = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
assert numpy.all(f1(vx) == vx)
assert numpy.all(f0(vx) == vx)
assert numpy.all(f2(vx,vc) == vx)
def test_shape_le_0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论