提交 e36afe03 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

clean up slop

上级 f82de085
......@@ -737,10 +737,6 @@ def local_div_exp_to_mul_exp(fgraph, node):
Multiplication is generally cheaper than division and the resulting
``exp(-B)`` may fuse with other exponentials via
``local_mul_exp_to_exp_add``.
We skip the case where the numerator is also ``exp(...)`` because
``local_mul_exp_to_exp_add`` already handles ``exp(A) / exp(B) → exp(A-B)``
directly on the ``true_div`` node.
"""
num, denom = node.inputs
......@@ -751,14 +747,6 @@ def local_div_exp_to_mul_exp(fgraph, node):
):
return None
# Skip if numerator is also exp — local_mul_exp_to_exp_add handles that
if (
num.owner
and isinstance(num.owner.op, Elemwise)
and isinstance(num.owner.op.scalar_op, ps.Exp)
):
return None
exp_arg = denom.owner.inputs[0]
new_out = num * exp(neg(exp_arg))
if new_out.dtype != node.outputs[0].dtype:
......
......@@ -3814,98 +3814,21 @@ def test_local_mul_exp_to_exp_add():
def test_local_div_exp_to_mul_exp():
mode = get_default_mode().excluding("fusion").including("local_div_exp_to_mul_exp")
x = scalar("x")
y = scalar("y")
# y / exp(x) -> y * exp(-x)
out = y / exp(x)
rewritten = rewrite_graph(out, include=["specialize"], exclude=["fusion"])
expected = y * exp(neg(x))
assert_equal_computations([rewritten], [expected])
# Also verify numerically
f = function([x, y], out, mode)
utt.assert_allclose(f(2.0, 3.0), 3.0 * np.exp(-2.0))
graph = f.maker.fgraph.toposort()
assert not any(
isinstance(n.op.scalar_op, ps.TrueDiv)
for n in graph
if isinstance(n.op, Elemwise)
)
# Matrices
mx = matrix("mx")
my = matrix("my")
f = function([mx, my], my / exp(mx), mode, allow_input_downcast=True)
M1 = np.array([[1.0, 2.0], [3.0, 4.0]])
M2 = np.array([[5.0, 6.0], [7.0, 8.0]])
utt.assert_allclose(f(M1, M2), M2 * np.exp(-M1))
graph = f.maker.fgraph.toposort()
assert not any(
isinstance(n.op.scalar_op, ps.TrueDiv)
for n in graph
if isinstance(n.op, Elemwise)
)
# exp(x) / exp(y) should NOT be affected (handled by local_mul_exp_to_exp_add)
# With only our rewrite enabled, the division should remain untouched
only_our_mode = (
get_default_mode()
.excluding("fusion", "specialize", "stabilize", "canonicalize")
.including("local_div_exp_to_mul_exp")
)
out = exp(x) / exp(y)
f = function([x, y], out, only_our_mode)
graph = f.maker.fgraph.toposort()
# The div should remain because our rewrite skips exp/exp
assert any(
isinstance(n.op.scalar_op, ps.TrueDiv)
for n in graph
if isinstance(n.op, Elemwise)
)
rewritten = rewrite_graph(out, include=("specialize",))
assert_equal_computations([rewritten], [expected])
# 1 / exp(x) -> exp(-x)
out = true_div(np.float64(1.0), exp(x))
f = function([x], out, mode)
utt.assert_allclose(f(3.0), np.exp(-3.0))
graph = f.maker.fgraph.toposort()
assert not any(
isinstance(n.op.scalar_op, ps.TrueDiv)
for n in graph
if isinstance(n.op, Elemwise)
)
# Sigmoid pattern exp(x) / (1 + exp(x)) must still work
sigmoid_mode = (
get_default_mode()
.excluding("fusion")
.including("local_div_exp_to_mul_exp", "stabilize")
)
out = exp(x) / (1 + exp(x))
f = function([x], out, sigmoid_mode)
utt.assert_allclose(f(0.0), 0.5)
utt.assert_allclose(f(-50.0), 1.0 / (1.0 + np.exp(50.0)), atol=1e-7)
# The sigmoid rewrite should have fired — the denominator (1+exp(x)) is
# NOT bare exp(), so local_div_exp_to_mul_exp should not interfere.
graph = f.maker.fgraph.toposort()
assert any(
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, ps.Sigmoid)
for n in graph
)
# Chain: y / exp(x) with further exp fusion
# y / exp(x) * exp(z) -> y * exp(-x) * exp(z) -> y * exp(z - x)
z = scalar("z")
chain_mode = (
get_default_mode()
.excluding("fusion")
.including("local_div_exp_to_mul_exp", "local_mul_exp_to_exp_add")
)
out = y / exp(x) * exp(z)
f = function([x, y, z], out, chain_mode)
utt.assert_allclose(f(2.0, 3.0, 5.0), 3.0 * np.exp(5.0 - 2.0))
expected = exp(neg(x))
rewritten = rewrite_graph(out, include=("specialize",))
assert_equal_computations([rewritten], [expected])
def test_local_mul_pow_to_pow_add():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论