提交 cff058c9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not apply `local_add_neg_to_sub` rewrite if negative variabe is a constant

上级 04ce1c6c
...@@ -535,30 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node): ...@@ -535,30 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node):
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@node_rewriter([sub]) @node_rewriter([add, sub])
def local_expm1(fgraph, node): def local_expm1(fgraph, node):
"""Detect ``exp(a) - 1`` and convert them to ``expm1(a)``.""" """Detect ``exp(a) - 1`` or ``-1 + exp(a)`` and convert them to ``expm1(a)``."""
in1, in2 = node.inputs if len(node.inputs) != 2:
out = node.outputs[0] # TODO: handle more than two inputs in add
return None
if ( if isinstance(node.op.scalar_op, ps.Sub):
in1.owner exp_x, other_inp = node.inputs
and isinstance(in1.owner.op, Elemwise) if not (
and isinstance(in1.owner.op.scalar_op, ps.Exp) exp_x.owner
and get_underlying_scalar_constant_value(in2, raise_not_constant=False) == 1 and isinstance(exp_x.owner.op, Elemwise)
): and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
in11 = in1.owner.inputs[0] and get_underlying_scalar_constant_value(
new_out = expm1(in11) other_inp, raise_not_constant=False
)
== 1
):
return None
else:
# Try both orders
other_inp, exp_x = node.inputs
for i in range(2):
if i == 1:
other_inp, exp_x = exp_x, other_inp
if (
exp_x.owner
and isinstance(exp_x.owner.op, Elemwise)
and isinstance(exp_x.owner.op.scalar_op, ps.Exp)
and get_underlying_scalar_constant_value(
other_inp, raise_not_constant=False
)
== -1
):
break
else: # no break
return None
if new_out.type.broadcastable != out.type.broadcastable: [old_out] = node.outputs
new_out = broadcast_arrays(in11, in2)[0]
if new_out.dtype != out.dtype: [x] = exp_x.owner.inputs
new_out = cast(new_out, dtype=out.dtype) if x.type.broadcastable != old_out.type.broadcastable:
x = broadcast_arrays(x, other_inp)[0]
if not out.type.is_super(new_out.type): new_out = expm1(x)
return
return [new_out] if new_out.dtype != old_out.dtype:
new_out = cast(new_out, dtype=old_out.dtype)
if not old_out.type.is_super(new_out.type):
return None
return [new_out]
@register_specialize @register_specialize
...@@ -1824,15 +1853,6 @@ def local_add_neg_to_sub(fgraph, node): ...@@ -1824,15 +1853,6 @@ def local_add_neg_to_sub(fgraph, node):
new_out = sub(first, pre_neg) new_out = sub(first, pre_neg)
return [new_out] return [new_out]
# Check if it is a negative constant
if (
isinstance(second, TensorConstant)
and second.unique_value is not None
and second.unique_value < 0
):
new_out = sub(first, np.abs(second.data))
return [new_out]
@register_canonicalize @register_canonicalize
@node_rewriter([mul]) @node_rewriter([mul])
...@@ -2606,9 +2626,9 @@ register_canonicalize(local_one_minus_erfc) ...@@ -2606,9 +2626,9 @@ register_canonicalize(local_one_minus_erfc)
register_stabilize(local_one_minus_erfc) register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc) register_specialize(local_one_minus_erfc)
# erfc(-x)-1=>erf(x) # -1 + erfc(-x)=>erf(x)
local_erf_neg_minus_one = PatternNodeRewriter( local_erf_neg_minus_one = PatternNodeRewriter(
(sub, (erfc, (neg, "x")), 1), (add, -1, (erfc, (neg, "x"))),
(erf, "x"), (erf, "x"),
allow_multiple_clients=True, allow_multiple_clients=True,
name="local_erf_neg_minus_one", name="local_erf_neg_minus_one",
......
...@@ -3806,14 +3806,9 @@ def test_local_expm1(): ...@@ -3806,14 +3806,9 @@ def test_local_expm1():
for n in h.maker.fgraph.toposort() for n in h.maker.fgraph.toposort()
) )
# This rewrite works when `local_add_neg_to_sub` specialization rewrite is invoked assert any(
expect_rewrite = config.mode != "FAST_COMPILE" isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
assert ( for n in r.maker.fgraph.toposort()
any(
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.basic.Expm1)
for n in r.maker.fgraph.toposort()
)
== expect_rewrite
) )
...@@ -4440,25 +4435,6 @@ def test_local_add_neg_to_sub(first_negative): ...@@ -4440,25 +4435,6 @@ def test_local_add_neg_to_sub(first_negative):
assert np.allclose(f(x_test, y_test), exp) assert np.allclose(f(x_test, y_test), exp)
@pytest.mark.parametrize("const_left", (True, False))
def test_local_add_neg_to_sub_const(const_left):
x = vector("x")
const = np.full((3, 2), 5.0)
out = -const + x if const_left else x + (-const)
f = function([x], out, mode=Mode("py"))
nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle | Alloc)
]
assert nodes == [pt.sub]
x_test = np.array([3, 4], dtype=config.floatX)
assert np.allclose(f(x_test), x_test + (-const))
def test_log1mexp_stabilization(): def test_log1mexp_stabilization():
mode = Mode("py").including("stabilize") mode = Mode("py").including("stabilize")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论