提交 12477108 authored 作者: Ricardo's avatar Ricardo 提交者: Rémi Louf

Add rewrite for subtraction with negation `x - (-y) -> x + y`

上级 77df6673
......@@ -1788,6 +1788,24 @@ def local_neg_div_neg(fgraph, node):
return [true_div(new_num, denom)]
@register_canonicalize
@register_specialize
@node_rewriter([sub])
def local_sub_neg_to_add(fgraph, node):
"""
x - (-y) -> x + y
"""
if node.op == sub:
minuend, subtrahend = node.inputs
if subtrahend.owner:
if subtrahend.owner.op == neg:
pre_neg = subtrahend.owner.inputs[0]
new_out = add(minuend, pre_neg)
return [new_out]
@register_canonicalize
@node_rewriter([mul])
def local_mul_zero(fgraph, node):
......
......@@ -4618,3 +4618,39 @@ def test_deprecations():
"""Make sure we can import from deprecated modules."""
with pytest.deprecated_call():
from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811
def test_local_sub_neg_to_add():
x = scalar("x")
y = vector("y")
f = function([x, y], x - (-y), mode=Mode("py"))
nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle)
]
assert nodes == [at.add]
x_test = np.full((), 1.0, dtype=config.floatX)
y_test = np.full(5, 2.0, dtype=config.floatX)
assert np.allclose(f(x_test, y_test), x_test - (-y_test))
def test_local_sub_neg_to_add_const():
# This rewrite is achieved by the local_add_canonizer
x = vector("x")
const = 5.0
f = function([x], x - (-const), mode=Mode("py"))
nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle)
]
assert nodes == [at.add]
x_test = np.array([3, 4], dtype=config.floatX)
assert np.allclose(f(x_test), x_test - (-const))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论