提交 ef22377d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug when broadcasting branches in local_useless_switch rewrite

上级 5a47550f
......@@ -1023,18 +1023,15 @@ def local_useless_switch(fgraph, node):
# if left is right -> left
if equivalent_up_to_constant_casting(left, right):
if left.type.broadcastable == out_bcast:
if left.type.broadcastable != out_bcast:
left, _ = broadcast_arrays(left, cond)
out_dtype = node.outputs[0].type.dtype
if left.type.dtype != out_dtype:
left = cast(left, out_dtype)
copy_stack_trace(node.outputs + left, left)
# When not casting, the other inputs of the switch aren't needed in the traceback
return [left]
else:
ret = broadcast_arrays(left, cond)[0]
copy_stack_trace(node.outputs + left, ret)
return [ret]
copy_stack_trace(node.outputs + node.inputs, left)
return [left]
# This case happens with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
......
......@@ -1089,6 +1089,25 @@ class TestLocalUselessSwitch:
assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc)
assert not any(node.op == pt.switch for node in f.maker.fgraph.toposort())
def test_broadcasting_different_dtype(self):
cond = vector("x", dtype="bool")
float32_branch = as_tensor(np.array([0], dtype="float32"))
float64_branch = as_tensor(np.array([0], dtype="float64"))
out = pt.switch(cond, float32_branch, float64_branch)
expected_out = pt.alloc(float64_branch, cond.shape)
rewritten_out = rewrite_graph(
out, include=("canonicalize", "stabilize", "specialize")
)
assert equal_computations([rewritten_out], [expected_out])
out = pt.switch(cond, float64_branch, float32_branch)
rewritten_out = rewrite_graph(
out, include=("canonicalize", "stabilize", "specialize")
)
assert equal_computations([rewritten_out], [expected_out])
class TestLocalMergeSwitchSameCond:
@pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论