提交 4f7d7096 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Simplify `local_[mul|div]_switch_sink` and fix downcasting bug

上级 4d0aa3f2
......@@ -621,46 +621,22 @@ def local_mul_switch_sink(fgraph, node):
part of the graph.
"""
for idx, i in enumerate(node.inputs):
if i.owner and i.owner.op == switch:
switch_node = i.owner
try:
if (
get_underlying_scalar_constant_value(
switch_node.inputs[1], only_process_constants=True
)
== 0.0
):
listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
fmul = mul(*([*listmul, switch_node.inputs[2]]))
for mul_inp_idx, mul_inp in enumerate(node.inputs):
if mul_inp.owner and mul_inp.owner.op == switch:
switch_node = mul_inp.owner
# Look for a zero as the first or second branch of the switch
for branch in range(2):
zero_switch_input = switch_node.inputs[1 + branch]
if not get_unique_constant_value(zero_switch_input) == 0.0:
continue
# Copy over stacktrace for elementwise multiplication op
# from previous elementwise multiplication op.
# An error in the multiplication (e.g. errors due to
# inconsistent shapes), will point to the
# multiplication op.
copy_stack_trace(node.outputs, fmul)
switch_cond = switch_node.inputs[0]
other_switch_input = switch_node.inputs[1 + (1 - branch)]
fct = [switch(switch_node.inputs[0], 0, fmul)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
listmul = list(node.inputs)
listmul[mul_inp_idx] = other_switch_input
fmul = mul(*listmul)
# Copy over stacktrace for switch op from both previous
# elementwise multiplication op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct
except NotScalarConstantError:
pass
try:
if (
get_underlying_scalar_constant_value(
switch_node.inputs[2], only_process_constants=True
)
== 0.0
):
listmul = node.inputs[:idx] + node.inputs[idx + 1 :]
fmul = mul(*([*listmul, switch_node.inputs[1]]))
# Copy over stacktrace for elementwise multiplication op
# from previous elementwise multiplication op.
# An error in the multiplication (e.g. errors due to
......@@ -668,18 +644,20 @@ def local_mul_switch_sink(fgraph, node):
# multiplication op.
copy_stack_trace(node.outputs, fmul)
fct = [switch(switch_node.inputs[0], fmul, 0)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
if branch == 0:
fct = switch(switch_cond, zero_switch_input, fmul)
else:
fct = switch(switch_cond, fmul, zero_switch_input)
# Tell debug_mode than the output is correct, even if nan disappear
fct.tag.values_eq_approx = values_eq_approx_remove_nan
# Copy over stacktrace for switch op from both previous
# elementwise multiplication op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct
except NotScalarConstantError:
pass
return False
return [fct]
@register_canonicalize
......@@ -699,43 +677,21 @@ def local_div_switch_sink(fgraph, node):
See `local_mul_switch_sink` for more details.
"""
op = node.op
if node.inputs[0].owner and node.inputs[0].owner.op == switch:
switch_node = node.inputs[0].owner
try:
if (
get_underlying_scalar_constant_value(
switch_node.inputs[1], only_process_constants=True
)
== 0.0
):
fdiv = op(switch_node.inputs[2], node.inputs[1])
# Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op.
# An error in the division (e.g. errors due to
# inconsistent shapes or division by zero),
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)
num, denom = node.inputs
fct = [switch(switch_node.inputs[0], 0, fdiv)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
if num.owner and num.owner.op == switch:
switch_node = num.owner
# Look for a zero as the first or second branch of the switch
for branch in range(2):
zero_switch_input = switch_node.inputs[1 + branch]
if not get_unique_constant_value(zero_switch_input) == 0.0:
continue
switch_cond = switch_node.inputs[0]
other_switch_input = switch_node.inputs[1 + (1 - branch)]
fdiv = node.op(other_switch_input, denom)
# Copy over stacktrace for switch op from both previous
# elementwise division op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct
except NotScalarConstantError:
pass
try:
if (
get_underlying_scalar_constant_value(
switch_node.inputs[2], only_process_constants=True
)
== 0.0
):
fdiv = op(switch_node.inputs[1], node.inputs[1])
# Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op.
# An error in the division (e.g. errors due to
......@@ -743,18 +699,17 @@ def local_div_switch_sink(fgraph, node):
# will point to the new division op.
copy_stack_trace(node.outputs, fdiv)
fct = [switch(switch_node.inputs[0], fdiv, 0)]
fct[0].tag.values_eq_approx = values_eq_approx_remove_nan
fct = switch(switch_cond, zero_switch_input, fdiv)
# Tell debug_mode than the output is correct, even if nan disappear
fct.tag.values_eq_approx = values_eq_approx_remove_nan
# Copy over stacktrace for switch op from both previous
# elementwise division op and previous switch op,
# because an error in this part can be caused by either
# of the two previous ops.
copy_stack_trace(node.outputs + switch_node.outputs, fct)
return fct
except NotScalarConstantError:
pass
return False
return [fct]
class AlgebraicCanonizer(NodeRewriter):
......
......@@ -97,9 +97,11 @@ from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.rewriting.math import (
compute_mul,
is_1pexp,
local_div_switch_sink,
local_grad_log_erfc_neg,
local_greedy_distributor,
local_mul_canonizer,
local_mul_switch_sink,
local_reduce_chain,
local_sum_prod_of_mul_or_div,
mul_canonizer,
......@@ -2115,7 +2117,6 @@ class TestLocalSwitchSink:
f = self.function_remove_nan([x], pytensor.gradient.grad(y, x), self.mode)
assert f(5) == 1, f(5)
@pytest.mark.slow
def test_local_div_switch_sink(self):
c = dscalar()
idx = 0
......@@ -2149,6 +2150,28 @@ class TestLocalSwitchSink:
].size
idx += 1
@pytest.mark.parametrize(
"op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)]
)
def test_local_mul_div_switch_sink_cast(self, op, rewrite):
"""Check that we don't downcast during the rewrite.
Regression test for: https://github.com/pymc-devs/pytensor/issues/1037
"""
cond = scalar("cond", dtype="bool")
# The zero branch upcasts the output, so we can't ignore its dtype
zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch")
other_branch = scalar("other_branch", dtype="float32")
outer_var = scalar("mul_var", dtype="bool")
out = op(switch(cond, zero_branch, other_branch), outer_var)
fgraph = FunctionGraph(outputs=[out], clone=False)
[new_out] = rewrite.transform(fgraph, out.owner)
assert new_out.type.dtype == out.type.dtype
expected_out = switch(cond, zero_branch, op(other_branch, outer_var))
assert equal_computations([new_out], [expected_out])
@pytest.mark.skipif(
config.cxx == "",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论