提交 97174b84 authored 作者: carriepl's avatar carriepl

Merge pull request #3619 from nouiz/switch_opt

fix gh-3614, fix switch opt when both branch are the same and we broa…
......@@ -3471,7 +3471,9 @@ def local_useless_switch(node):
return [out]
# if left is right -> left
if node.inputs[1] is node.inputs[2]:
return [node.inputs[1]]
if cond.type == node.inputs[1].type:
return [node.inputs[1]]
return [T.fill(cond, node.inputs[1])]
# This case happens with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
......
......@@ -4589,7 +4589,7 @@ class T_local_erfc(unittest.TestCase):
print(t1 - t0, t2 - t1)
class test_local_remove_switch_const_cond(unittest.TestCase):
class test_local_useless_switch(unittest.TestCase):
def setUp(self):
self.mode = mode_opt.excluding('constant_folding')
......@@ -4717,6 +4717,19 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
vy = numpy.array([[7, 8, 9], [10, 11, 12]], dtype='int64')
assert numpy.all(f(vx, vy) == vy)
def test_broadcast3(self):
# test switch(matrix, same_vector, same_vector)
x = theano.tensor.matrix('x', dtype='int32')
y = theano.tensor.vector('y', dtype='int64')
z = theano.tensor.switch(x, y, y)
f = theano.function([x, y], z, mode=self.mode)
vx = numpy.array([[0, 1], [1, 0]], dtype='int32')
vy = numpy.array([7, 8], dtype='int64')
utt.assert_allclose(f(vx, vy), numpy.where(vx, vy, vy))
assert len([node.op for node in f.maker.fgraph.toposort() if
isinstance(node.op, theano.tensor.Elemwise)]) == 0
class T_local_sum_prod(unittest.TestCase):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论