提交 975ca888 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

local_[div|mul]_switch_sink: Allow zero inside a DimShuffle

上级 a314476f
......@@ -840,6 +840,7 @@ def local_expm1(fgraph, node):
@register_specialize
@register_stabilize
@register_canonicalize
@node_rewriter([mul])
def local_mul_switch_sink(fgraph, node):
......@@ -878,14 +879,17 @@ def local_mul_switch_sink(fgraph, node):
switch_node = mul_inp.owner
# Look for a zero as the first or second branch of the switch
for branch in range(2):
zero_switch_input = switch_node.inputs[1 + branch]
if (
not get_underlying_scalar_constant_value(
zero_switch_input,
only_process_constants=True,
raise_not_constant=False,
)
== 0.0
zero_inp = underlying_zero = switch_node.inputs[1 + branch]
# Allow zero inside a DimShuffle or Alloc
if zero_inp.owner is not None and isinstance(
zero_inp.owner.op, DimShuffle | Alloc
):
underlying_zero = zero_inp.owner.inputs[0]
if not (
isinstance(underlying_zero, TensorConstant)
and underlying_zero.unique_value == 0
):
continue
......@@ -904,9 +908,9 @@ def local_mul_switch_sink(fgraph, node):
copy_stack_trace(node.outputs, fmul)
if branch == 0:
fct = switch(switch_cond, zero_switch_input, fmul)
fct = switch(switch_cond, zero_inp, fmul)
else:
fct = switch(switch_cond, fmul, zero_switch_input)
fct = switch(switch_cond, fmul, zero_inp)
# Tell debug_mode than the output is correct, even if nan disappear
fct.tag.values_eq_approx = values_eq_approx_remove_nan
......@@ -942,14 +946,17 @@ def local_div_switch_sink(fgraph, node):
switch_node = num.owner
# Look for a zero as the first or second branch of the switch
for branch in range(2):
zero_switch_input = switch_node.inputs[1 + branch]
if (
not get_underlying_scalar_constant_value(
zero_switch_input,
only_process_constants=True,
raise_not_constant=False,
)
== 0.0
zero_inp = underlying_zero = switch_node.inputs[1 + branch]
# Allow zero inside a DimShuffle or Alloc
if zero_inp.owner is not None and isinstance(
zero_inp.owner.op, DimShuffle | Alloc
):
underlying_zero = zero_inp.owner.inputs[0]
if not (
isinstance(underlying_zero, TensorConstant)
and underlying_zero.unique_value == 0
):
continue
......@@ -966,9 +973,9 @@ def local_div_switch_sink(fgraph, node):
copy_stack_trace(node.outputs, fdiv)
if branch == 0:
fct = switch(switch_cond, zero_switch_input, fdiv)
fct = switch(switch_cond, zero_inp, fdiv)
else:
fct = switch(switch_cond, fdiv, zero_switch_input)
fct = switch(switch_cond, fdiv, zero_inp)
# Tell debug_mode than the output is correct, even if nan disappear
fct.tag.values_eq_approx = values_eq_approx_remove_nan
......
......@@ -2278,6 +2278,27 @@ class TestLocalSwitchSink:
[new_left, new_right], [expected_left, expected_right]
)
def test_safe_constant_fold(self):
inner_mul = [1.5, 0, 1]
outer_mul = [1.5, -np.inf, -np.inf]
valid = [True, False, True]
for inner_zero in (
pt.expand_dims(np.array(0), 0),
pt.zeros((3,)),
):
out = (
switch(
valid,
inner_mul,
inner_zero,
)
* outer_mul
)
# If the rewrite doesn't happen before constant_folding, the middle term will be nan
np.testing.assert_allclose(out.eval(), [1.5**2, 0, -np.inf])
@pytest.mark.skipif(
config.cxx == "",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论