提交 4f7d7096 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Simplify `local_[mul|div]_switch_sink` and fix downcasting bug

上级 4d0aa3f2
...@@ -621,65 +621,43 @@ def local_mul_switch_sink(fgraph, node): ...@@ -621,65 +621,43 @@ def local_mul_switch_sink(fgraph, node):
part of the graph. part of the graph.
""" """
for idx, i in enumerate(node.inputs): for mul_inp_idx, mul_inp in enumerate(node.inputs):
if i.owner and i.owner.op == switch: if mul_inp.owner and mul_inp.owner.op == switch:
switch_node = i.owner switch_node = mul_inp.owner
try: # Look for a zero as the first or second branch of the switch
if ( for branch in range(2):
get_underlying_scalar_constant_value( zero_switch_input = switch_node.inputs[1 + branch]
switch_node.inputs[1], only_process_constants=True if not get_unique_constant_value(zero_switch_input) == 0.0:
) continue
== 0.0
): switch_cond = switch_node.inputs[0]
listmul = node.inputs[:idx] + node.inputs[idx + 1 :] other_switch_input = switch_node.inputs[1 + (1 - branch)]
fmul = mul(*([*listmul, switch_node.inputs[2]]))
listmul = list(node.inputs)
# Copy over stacktrace for elementwise multiplication op listmul[mul_inp_idx] = other_switch_input
# from previous elementwise multiplication op. fmul = mul(*listmul)
# An error in the multiplication (e.g. errors due to
# inconsistent shapes), will point to the # Copy over stacktrace for elementwise multiplication op
# multiplication op. # from previous elementwise multiplication op.
copy_stack_trace(node.outputs, fmul) # An error in the multiplication (e.g. errors due to
# inconsistent shapes), will point to the
fct = [switch(switch_node.inputs[0], 0, fmul)] # multiplication op.
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan copy_stack_trace(node.outputs, fmul)
# Copy over stacktrace for switch op from both previous if branch == 0:
# elementwise multiplication op and previous switch op, fct = switch(switch_cond, zero_switch_input, fmul)
# because an error in this part can be caused by either else:
# of the two previous ops. fct = switch(switch_cond, fmul, zero_switch_input)
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct # Tell debug_mode than the output is correct, even if nan disappear
except NotScalarConstantError: fct.tag.values_eq_approx = values_eq_approx_remove_nan
pass
try: # Copy over stacktrace for switch op from both previous
if ( # elementwise multiplication op and previous switch op,
get_underlying_scalar_constant_value( # because an error in this part can be caused by either
switch_node.inputs[2], only_process_constants=True # of the two previous ops.
) copy_stack_trace(node.outputs + switch_node.outputs, fct)
== 0.0 return [fct]
):
listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
fmul = mul(*([*listmul, switch_node.inputs[1]]))
# Copy over stacktrace for elementwise multiplication op
# from previous elementwise multiplication op.
# An error in the multiplication (e.g. errors due to
# inconsistent shapes), will point to the
# multiplication op.
copy_stack_trace(node.outputs, fmul)
fct = [switch(switch_node.inputs[0], fmul, 0)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
# Copy over stacktrace for switch op from both previous
# elementwise multiplication op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct
except NotScalarConstantError:
pass
return False
@register_canonicalize @register_canonicalize
...@@ -699,62 +677,39 @@ def local_div_switch_sink(fgraph, node): ...@@ -699,62 +677,39 @@ def local_div_switch_sink(fgraph, node):
See `local_mul_switch_sink` for more details. See `local_mul_switch_sink` for more details.
""" """
op = node.op num, denom = node.inputs
if node.inputs[0].owner and node.inputs[0].owner.op == switch:
switch_node = node.inputs[0].owner
try:
if (
get_underlying_scalar_constant_value(
switch_node.inputs[1], only_process_constants=True
)
== 0.0
):
fdiv = op(switch_node.inputs[2], node.inputs[1])
# Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op.
# An error in the division (e.g. errors due to
# inconsistent shapes or division by zero),
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)
fct = [switch(switch_node.inputs[0], 0, fdiv)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
# Copy over stacktrace for switch op from both previous
# elementwise division op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct
except NotScalarConstantError:
pass
try:
if (
get_underlying_scalar_constant_value(
switch_node.inputs[2], only_process_constants=True
)
== 0.0
):
fdiv = op(switch_node.inputs[1], node.inputs[1])
# Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op.
# An error in the division (e.g. errors due to
# inconsistent shapes or division by zero),
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)
fct = [switch(switch_node.inputs[0], fdiv, 0)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
# Copy over stacktrace for switch op from both previous if num.owner and num.owner.op == switch:
# elementwise division op and previous switch op, switch_node = num.owner
# because an error in this part can be caused by either # Look for a zero as the first or second branch of the switch
# of the two previous ops. for branch in range(2):
copy_stack_trace(node.outputs + switch_node.outputs, fct) zero_switch_input = switch_node.inputs[1 + branch]
return fct if not get_unique_constant_value(zero_switch_input) == 0.0:
except NotScalarConstantError: continue
pass
return False switch_cond = switch_node.inputs[0]
other_switch_input = switch_node.inputs[1 + (1 - branch)]
fdiv = node.op(other_switch_input, denom)
# Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op.
# An error in the division (e.g. errors due to
# inconsistent shapes or division by zero),
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)
fct = switch(switch_cond, zero_switch_input, fdiv)
# Tell debug_mode than the output is correct, even if nan disappear
fct.tag.values_eq_approx = values_eq_approx_remove_nan
# Copy over stacktrace for switch op from both previous
# elementwise division op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return [fct]
class AlgebraicCanonizer(NodeRewriter): class AlgebraicCanonizer(NodeRewriter):
......
...@@ -97,9 +97,11 @@ from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift ...@@ -97,9 +97,11 @@ from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.rewriting.math import ( from pytensor.tensor.rewriting.math import (
compute_mul, compute_mul,
is_1pexp, is_1pexp,
local_div_switch_sink,
local_grad_log_erfc_neg, local_grad_log_erfc_neg,
local_greedy_distributor, local_greedy_distributor,
local_mul_canonizer, local_mul_canonizer,
local_mul_switch_sink,
local_reduce_chain, local_reduce_chain,
local_sum_prod_of_mul_or_div, local_sum_prod_of_mul_or_div,
mul_canonizer, mul_canonizer,
...@@ -2115,7 +2117,6 @@ class TestLocalSwitchSink: ...@@ -2115,7 +2117,6 @@ class TestLocalSwitchSink:
f = self.function_remove_nan([x], pytensor.gradient.grad(y, x), self.mode) f = self.function_remove_nan([x], pytensor.gradient.grad(y, x), self.mode)
assert f(5) == 1, f(5) assert f(5) == 1, f(5)
@pytest.mark.slow
def test_local_div_switch_sink(self): def test_local_div_switch_sink(self):
c = dscalar() c = dscalar()
idx = 0 idx = 0
...@@ -2149,6 +2150,28 @@ class TestLocalSwitchSink: ...@@ -2149,6 +2150,28 @@ class TestLocalSwitchSink:
].size ].size
idx += 1 idx += 1
@pytest.mark.parametrize(
"op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)]
)
def test_local_mul_div_switch_sink_cast(self, op, rewrite):
"""Check that we don't downcast during the rewrite.
Regression test for: https://github.com/pymc-devs/pytensor/issues/1037
"""
cond = scalar("cond", dtype="bool")
# The zero branch upcasts the output, so we can't ignore its dtype
zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch")
other_branch = scalar("other_branch", dtype="float32")
outer_var = scalar("mul_var", dtype="bool")
out = op(switch(cond, zero_branch, other_branch), outer_var)
fgraph = FunctionGraph(outputs=[out], clone=False)
[new_out] = rewrite.transform(fgraph, out.owner)
assert new_out.type.dtype == out.type.dtype
expected_out = switch(cond, zero_branch, op(other_branch, outer_var))
assert equal_computations([new_out], [expected_out])
@pytest.mark.skipif( @pytest.mark.skipif(
config.cxx == "", config.cxx == "",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论