提交 bbf937c3 authored 作者: Ricardo's avatar Ricardo 提交者: Rémi Louf

Add rewrite for addition with negation `x + (-y) -> x - y`

* Also reverses test expectation for `local_expm1` rewrite, as previously missed case is now detected
上级 12477108
......@@ -1806,6 +1806,35 @@ def local_sub_neg_to_add(fgraph, node):
return [new_out]
@register_specialize
@node_rewriter([add])
def local_add_neg_to_sub(fgraph, node):
"""
-x + y -> y - x
x + (-y) -> x - y
"""
# This rewrite is only registered during specialization, because the
# `local_neg_to_mul` rewrite modifies the relevant pattern during canonicalization
# Rewrite is only applicable when there are two inputs to add
if node.op == add and len(node.inputs) == 2:
# Look for pattern with either input order
for first, second in (node.inputs, reversed(node.inputs)):
if second.owner:
if second.owner.op == neg:
pre_neg = second.owner.inputs[0]
new_out = sub(first, pre_neg)
return [new_out]
# Check if it is a negative constant
const = get_constant(second)
if const is not None and const < 0:
new_out = sub(first, np.abs(const))
return [new_out]
@register_canonicalize
@node_rewriter([mul])
def local_mul_zero(fgraph, node):
......
......@@ -4042,9 +4042,14 @@ def test_local_expm1():
for n in h.maker.fgraph.toposort()
)
assert not any(
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, aes.basic.Expm1)
for n in r.maker.fgraph.toposort()
# This rewrite works when `local_add_neg_to_sub` specialization rewrite is invoked
expect_rewrite = config.mode != "FAST_COMPILE"
assert (
any(
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, aes.basic.Expm1)
for n in r.maker.fgraph.toposort()
)
== expect_rewrite
)
......@@ -4654,3 +4659,41 @@ def test_local_sub_neg_to_add_const():
x_test = np.array([3, 4], dtype=config.floatX)
assert np.allclose(f(x_test), x_test - (-const))
@pytest.mark.parametrize("first_negative", (True, False))
def test_local_add_neg_to_sub(first_negative):
x = scalar("x")
y = vector("y")
out = -x + y if first_negative else x + (-y)
f = function([x, y], out, mode=Mode("py"))
nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle)
]
assert nodes == [at.sub]
x_test = np.full((), 1.0, dtype=config.floatX)
y_test = np.full(5, 2.0, dtype=config.floatX)
exp = -x_test + y_test if first_negative else x_test + (-y_test)
assert np.allclose(f(x_test, y_test), exp)
def test_local_add_neg_to_sub_const():
x = vector("x")
const = 5.0
f = function([x], x + (-const), mode=Mode("py"))
nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle)
]
assert nodes == [at.sub]
x_test = np.array([3, 4], dtype=config.floatX)
assert np.allclose(f(x_test), x_test + (-const))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论