提交 fdbf3aa5 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix bug in local_div_switch_sink rewrite

Introduced in 4f7d7096
上级 c2e88c6e
...@@ -699,7 +699,10 @@ def local_div_switch_sink(fgraph, node): ...@@ -699,7 +699,10 @@ def local_div_switch_sink(fgraph, node):
# will point to the new division op. # will point to the new division op.
copy_stack_trace(node.outputs, fdiv) copy_stack_trace(node.outputs, fdiv)
fct = switch(switch_cond, zero_switch_input, fdiv) if branch == 0:
fct = switch(switch_cond, zero_switch_input, fdiv)
else:
fct = switch(switch_cond, fdiv, zero_switch_input)
# Tell debug_mode than the output is correct, even if nan disappear # Tell debug_mode than the output is correct, even if nan disappear
fct.tag.values_eq_approx = values_eq_approx_remove_nan fct.tag.values_eq_approx = values_eq_approx_remove_nan
......
...@@ -2163,7 +2163,7 @@ class TestLocalSwitchSink: ...@@ -2163,7 +2163,7 @@ class TestLocalSwitchSink:
# The zero branch upcasts the output, so we can't ignore its dtype # The zero branch upcasts the output, so we can't ignore its dtype
zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch") zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch")
other_branch = scalar("other_branch", dtype="float32") other_branch = scalar("other_branch", dtype="float32")
outer_var = scalar("mul_var", dtype="bool") outer_var = scalar("outer_var", dtype="bool")
out = op(switch(cond, zero_branch, other_branch), outer_var) out = op(switch(cond, zero_branch, other_branch), outer_var)
fgraph = FunctionGraph(outputs=[out], clone=False) fgraph = FunctionGraph(outputs=[out], clone=False)
...@@ -2173,6 +2173,27 @@ class TestLocalSwitchSink: ...@@ -2173,6 +2173,27 @@ class TestLocalSwitchSink:
expected_out = switch(cond, zero_branch, op(other_branch, outer_var)) expected_out = switch(cond, zero_branch, op(other_branch, outer_var))
assert equal_computations([new_out], [expected_out]) assert equal_computations([new_out], [expected_out])
@pytest.mark.parametrize(
"op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)]
)
def test_local_mul_div_switch_sink_branch_order(self, op, rewrite):
cond = scalar("cond", dtype="bool")
zero_branch = constant(np.array(0.0, dtype="float64"), "zero_branch")
other_branch = scalar("other_branch", dtype="float64")
outer_var = scalar("outer_var", dtype="float64")
left = op(switch(cond, zero_branch, other_branch), outer_var)
right = op(switch(cond, other_branch, zero_branch), outer_var)
fgraph = FunctionGraph(outputs=[left, right], clone=False)
[new_left] = rewrite.transform(fgraph, left.owner)
[new_right] = rewrite.transform(fgraph, right.owner)
expected_left = switch(cond, zero_branch, op(other_branch, outer_var))
expected_right = switch(cond, op(other_branch, outer_var), zero_branch)
assert equal_computations(
[new_left, new_right], [expected_left, expected_right]
)
@pytest.mark.skipif( @pytest.mark.skipif(
config.cxx == "", config.cxx == "",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论