提交 cff058c9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not apply `local_add_neg_to_sub` rewrite if negative variabe is a constant

上级 04ce1c6c
......@@ -535,30 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node):
@register_stabilize
@register_specialize
@register_canonicalize
@node_rewriter([sub])
@node_rewriter([add, sub])
def local_expm1(fgraph, node):
"""Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
in1, in2 = node.inputs
out = node.outputs[0]
"""Detect ``exp(a) - 1`` or ``-1 + exp(a)`` and convert them to ``expm1(a)``."""
if len(node.inputs) != 2:
# TODO: handle more than two inputs in add
return None
if (
in1.owner
and isinstance(in1.owner.op, Elemwise)
and isinstance(in1.owner.op.scalar_op, ps.Exp)
and get_underlying_scalar_constant_value(in2, raise_not_constant=False) == 1
):
in11 = in1.owner.inputs[0]
new_out = expm1(in11)
if isinstance(node.op.scalar_op, ps.Sub):
exp_x, other_inp = node.inputs
if not (
exp_x.owner
and isinstance(exp_x.owner.op, Elemwise)
and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
and get_underlying_scalar_constant_value(
other_inp, raise_not_constant=False
)
== 1
):
return None
else:
# Try both orders
other_inp, exp_x = node.inputs
for i in range(2):
if i == 1:
other_inp, exp_x = exp_x, other_inp
if (
exp_x.owner
and isinstance(exp_x.owner.op, Elemwise)
and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
and get_underlying_scalar_constant_value(
other_inp, raise_not_constant=False
)
== -1
):
break
else: # no break
return None
if new_out.type.broadcastable != out.type.broadcastable:
new_out = broadcast_arrays(in11, in2)[0]
[old_out] = node.outputs
if new_out.dtype != out.dtype:
new_out = cast(new_out, dtype=out.dtype)
[x] = exp_x.owner.inputs
if x.type.broadcastable != old_out.type.broadcastable:
x = broadcast_arrays(x, other_inp)[0]
if not out.type.is_super(new_out.type):
return
return [new_out]
new_out = expm1(x)
if new_out.dtype != old_out.dtype:
new_out = cast(new_out, dtype=old_out.dtype)
if not old_out.type.is_super(new_out.type):
return None
return [new_out]
@register_specialize
......@@ -1824,15 +1853,6 @@ def local_add_neg_to_sub(fgraph, node):
new_out = sub(first, pre_neg)
return [new_out]
# Check if it is a negative constant
if (
isinstance(second, TensorConstant)
and second.unique_value is not None
and second.unique_value < 0
):
new_out = sub(first, np.abs(second.data))
return [new_out]
@register_canonicalize
@node_rewriter([mul])
......@@ -2606,9 +2626,9 @@ register_canonicalize(local_one_minus_erfc)
register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc)
# erfc(-x)-1=>erf(x)
# -1 + erfc(-x)=>erf(x)
local_erf_neg_minus_one = PatternNodeRewriter(
(sub, (erfc, (neg, "x")), 1),
(add, -1, (erfc, (neg, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_erf_neg_minus_one",
......
......@@ -3806,14 +3806,9 @@ def test_local_expm1():
for n in h.maker.fgraph.toposort()
)
# This rewrite works when `local_add_neg_to_sub` specialization rewrite is invoked
expect_rewrite = config.mode != "FAST_COMPILE"
assert (
any(
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
for n in r.maker.fgraph.toposort()
)
== expect_rewrite
assert any(
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
for n in r.maker.fgraph.toposort()
)
......@@ -4440,25 +4435,6 @@ def test_local_add_neg_to_sub(first_negative):
assert np.allclose(f(x_test, y_test), exp)
@pytest.mark.parametrize("const_left", (True, False))
def test_local_add_neg_to_sub_const(const_left):
x = vector("x")
const = np.full((3, 2), 5.0)
out = -const + x if const_left else x + (-const)
f = function([x], out, mode=Mode("py"))
nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle | Alloc)
]
assert nodes == [pt.sub]
x_test = np.array([3, 4], dtype=config.floatX)
assert np.allclose(f(x_test), x_test + (-const))
def test_log1mexp_stabilization():
mode = Mode("py").including("stabilize")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论