Unverified 提交 e37497fd authored 作者: Tamás Tőkés's avatar Tamás Tőkés 提交者: GitHub

Rewrite products of exponents as exponent of sum (#186)

* Rewrite products of exponentials as the exponential of a sum: e^x * e^y becomes e^(x+y), and e^x / e^y becomes e^(x-y).
* Rewrite a^x * a^y to a^(x+y).
上级 5628ab15
...@@ -2,6 +2,7 @@ r"""Rewrites for the `Op`\s in :mod:`pytensor.tensor.math`.""" ...@@ -2,6 +2,7 @@ r"""Rewrites for the `Op`\s in :mod:`pytensor.tensor.math`."""
import itertools import itertools
import operator import operator
from collections import defaultdict
from functools import partial, reduce from functools import partial, reduce
import numpy as np import numpy as np
...@@ -423,6 +424,100 @@ def local_sumsqr2dot(fgraph, node): ...@@ -423,6 +424,100 @@ def local_sumsqr2dot(fgraph, node):
return [new_out] return [new_out]
@register_specialize
@node_rewriter([mul, true_div])
def local_mul_exp_to_exp_add(fgraph, node):
    """Fuse exponentials that are multiplied or divided into one exponential.

    ``e^x * e^y`` is replaced by ``e^(x+y)`` and ``e^x / e^y`` by ``e^(x-y)``.
    Inputs of `node` that are not exponentials are kept and recombined with
    the fused exponential through the node's original operation, e.g.
    ``e^x * y * e^z * w --> e^(x+z) * y * w``.

    Returns a one-element list with the replacement output, or ``None`` when
    fewer than two inputs of `node` are exponentials.
    """

    def is_exp(var):
        # True when `var` is produced by an op whose `scalar_op` is an `Exp`
        owner = var.owner
        return bool(
            owner
            and hasattr(owner.op, "scalar_op")
            and isinstance(owner.op.scalar_op, aes.Exp)
        )

    # Split the inputs into the arguments of the exp-s and everything else,
    # keeping the original input order within each group.
    exp_args = []
    others = []
    for inp in node.inputs:
        if is_exp(inp):
            exp_args.append(inp.owner.inputs[0])
        else:
            others.append(inp)

    # Nothing to fuse unless there are at least two exp-s
    if len(exp_args) < 2:
        return None

    out_dtype = node.outputs[0].dtype
    is_division = isinstance(node.op.scalar_op, aes.TrueDiv)

    # A product of exp-s adds the exponents; a quotient subtracts them
    combine_exponents = sub if is_division else add
    fused = exp(combine_exponents(*exp_args))
    if fused.dtype != out_dtype:
        fused = cast(fused, dtype=out_dtype)

    if others:
        # Put the non-exp inputs back using the operation of the original node
        recombine = true_div if is_division else mul
        fused = recombine(fused, *others)
        if fused.dtype != out_dtype:
            fused = cast(fused, dtype=out_dtype)

    return [fused]
@register_specialize
@node_rewriter([mul, true_div])
def local_mul_pow_to_pow_add(fgraph, node):
    """Merge powers that share a base into a single power.

    This rewrite detects ``a^x * a^y`` and converts it to ``a^(x+y)``.
    Similarly, ``a^x / a^y`` becomes ``a^(x-y)``.

    Powers are grouped by the variable used as their base. Every base that
    occurs in at least two power inputs of `node` is rewritten; the remaining
    powers and the non-power inputs are kept unchanged and recombined with
    the node's original operation.

    Returns a one-element list with the replacement output, or ``None`` when
    no base is shared by at least two power inputs.
    """
    # Search for pow-s and group them by their bases
    pow_nodes = defaultdict(list)
    rest = []
    for n in node.inputs:
        if (
            n.owner
            and hasattr(n.owner.op, "scalar_op")
            and isinstance(n.owner.op.scalar_op, aes.Pow)
        ):
            base_node = n.owner.inputs[0]
            # The exponent is at n.owner.inputs[1], but we need to store the full
            # variable in case this particular pow remains alone and can't be rewritten
            pow_nodes[base_node].append(n)
        else:
            rest.append(n)

    # Can only do any rewrite if there are at least two pow-s with the same base
    can_rewrite = [base for base, pows in pow_nodes.items() if len(pows) >= 2]
    if not can_rewrite:
        return None

    out_dtype = node.outputs[0].dtype

    # Mul -> add; TrueDiv -> sub
    orig_op, new_op = mul, add
    if isinstance(node.op.scalar_op, aes.TrueDiv):
        orig_op, new_op = true_div, sub

    # Rewrite pow-s having the same base, once for each different base
    # E.g.: a^x * a^y --> a^(x+y)
    pow_factors = []
    for base in can_rewrite:
        exponents = [n.owner.inputs[1] for n in pow_nodes[base]]
        new_node = base ** new_op(*exponents)
        if new_node.dtype != out_dtype:
            new_node = cast(new_node, dtype=out_dtype)
        pow_factors.append(new_node)

    # Don't forget about those sole pow-s that couldn't be rewritten. Every group
    # holds at least one pow, so these are exactly the groups not in `can_rewrite`.
    sole_pows = [pows[0] for pows in pow_nodes.values() if len(pows) < 2]

    # Combine the rewritten pow-s and other, non-pow factors of the original Mul
    # E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+w) * b^(z+t) * y * v
    if len(pow_factors) > 1 or sole_pows or rest:
        new_out = orig_op(*pow_factors, *sole_pows, *rest)
        if new_out.dtype != out_dtype:
            new_out = cast(new_out, dtype=out_dtype)
    else:
        # If all factors of the original node were pow-s with the same base,
        # we can get rid of the outer operation completely.
        new_out = pow_factors[0]

    return [new_out]
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
......
...@@ -4014,6 +4014,161 @@ def test_local_sumsqr2dot(): ...@@ -4014,6 +4014,161 @@ def test_local_sumsqr2dot():
) )
def test_local_mul_exp_to_exp_add():
    """Check that products and quotients of exp-s are rewritten to a single exp."""
    # Default and FAST_RUN modes put a Composite op into the final graph,
    # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
    # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
    mode = get_default_mode().excluding("fusion").including("local_mul_exp_to_exp_add")

    def elemwise_graph(fn):
        # Nodes of the compiled graph; all of them are required to be Elemwise
        # so that `scalar_op` can be inspected on each one afterwards.
        nodes = fn.maker.fgraph.toposort()
        assert all(isinstance(n.op, Elemwise) for n in nodes)
        return nodes

    def has_scalar_op(nodes, scalar_op_type):
        return any(isinstance(n.op.scalar_op, scalar_op_type) for n in nodes)

    x = scalar("x")
    y = scalar("y")
    z = scalar("z")
    w = scalar("w")
    expx = exp(x)
    expy = exp(y)
    expz = exp(z)
    expw = exp(w)

    # e^x * e^y * e^z * e^w = e^(x+y+z+w)
    op = expx * expy * expz * expw
    f = function([x, y, z, w], op, mode)
    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6))
    graph = elemwise_graph(f)
    assert has_scalar_op(graph, aes.Add)
    assert not has_scalar_op(graph, aes.Mul)

    # e^x * e^y * e^z / e^w = e^(x+y+z-w)
    op = expx * expy * expz / expw
    f = function([x, y, z, w], op, mode)
    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 - 6))
    graph = elemwise_graph(f)
    assert has_scalar_op(graph, aes.Add)
    assert has_scalar_op(graph, aes.Sub)
    assert not has_scalar_op(graph, aes.Mul)
    assert not has_scalar_op(graph, aes.TrueDiv)

    # e^x * e^y / e^z * e^w = e^(x+y-z+w)
    op = expx * expy / expz * expw
    f = function([x, y, z, w], op, mode)
    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 - 5 + 6))
    graph = elemwise_graph(f)
    assert has_scalar_op(graph, aes.Add)
    assert has_scalar_op(graph, aes.Sub)
    assert not has_scalar_op(graph, aes.Mul)
    assert not has_scalar_op(graph, aes.TrueDiv)

    # e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z)
    op = expx / expy / expz
    f = function([x, y, z], op, mode)
    utt.assert_allclose(f(3, 4, 5), np.exp(3 - 4 - 5))
    graph = elemwise_graph(f)
    assert has_scalar_op(graph, aes.Sub)
    assert not has_scalar_op(graph, aes.TrueDiv)

    # e^x * y * e^z * w = e^(x+z) * y * w
    op = expx * y * expz * w
    f = function([x, y, z, w], op, mode)
    utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 5) * 4 * 6)
    graph = elemwise_graph(f)
    assert has_scalar_op(graph, aes.Add)
    assert has_scalar_op(graph, aes.Mul)

    # expect same for matrices as well
    mx = matrix("mx")
    my = matrix("my")
    f = function([mx, my], exp(mx) * exp(my), mode, allow_input_downcast=True)
    M1 = np.array([[1.0, 2.0], [3.0, 4.0]])
    M2 = np.array([[5.0, 6.0], [7.0, 8.0]])
    utt.assert_allclose(f(M1, M2), np.exp(M1 + M2))
    graph = elemwise_graph(f)
    assert has_scalar_op(graph, aes.Add)
    assert not has_scalar_op(graph, aes.Mul)

    # checking whether further rewrites can proceed after this one as one would expect
    # e^x * e^(-x) = e^(x-x) = e^0 = 1
    f = function([x], expx * exp(neg(x)), mode)
    utt.assert_allclose(f(42), 1)
    graph = f.maker.fgraph.toposort()
    assert isinstance(graph[0].inputs[0], TensorConstant)

    # e^x / e^x = e^(x-x) = e^0 = 1
    f = function([x], expx / expx, mode)
    utt.assert_allclose(f(42), 1)
    graph = f.maker.fgraph.toposort()
    assert isinstance(graph[0].inputs[0], TensorConstant)
def test_local_mul_pow_to_pow_add():
    """Check that pow-s sharing a base are merged into a single pow."""
    # Default and FAST_RUN modes put a Composite op into the final graph,
    # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
    # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
    mode = get_default_mode().excluding("fusion")
    mode = mode.including("local_mul_exp_to_exp_add")
    mode = mode.including("local_mul_pow_to_pow_add")

    def elemwise_graph(fn):
        # Nodes of the compiled graph; all of them are required to be Elemwise
        # so that `scalar_op` can be inspected on each one afterwards.
        nodes = fn.maker.fgraph.toposort()
        assert all(isinstance(n.op, Elemwise) for n in nodes)
        return nodes

    def has_scalar_op(nodes, scalar_op_type):
        return any(isinstance(n.op.scalar_op, scalar_op_type) for n in nodes)

    def count_scalar_op(nodes, scalar_op_type):
        return sum(isinstance(n.op.scalar_op, scalar_op_type) for n in nodes)

    x, y, z, w = scalar("x"), scalar("y"), scalar("z"), scalar("w")
    v, u, t, s = scalar("v"), scalar("u"), scalar("t"), scalar("s")
    a, b, c = scalar("a"), scalar("b"), scalar("c")

    # 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w)
    expr = 2**x * 2**y * 2**z * 2**w
    fn = function([x, y, z, w], expr, mode)
    utt.assert_allclose(fn(3, 4, 5, 6), 2 ** (3 + 4 + 5 + 6))
    nodes = elemwise_graph(fn)
    assert has_scalar_op(nodes, aes.Add)
    assert not has_scalar_op(nodes, aes.Mul)

    # 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s
    expr = 2**x * a**y * 2**z * b**w * c**v * a**u * s * b**t
    fn = function([x, y, z, w, v, u, t, s, a, b, c], expr, mode)
    expected = 2 ** (4 + 6) * 2.5 ** (5 + 9) * 3 ** (7 + 10) * 3.5**8 * 11
    utt.assert_allclose(fn(4, 5, 6, 7, 8, 9, 10, 11, 2.5, 3, 3.5), expected)
    nodes = elemwise_graph(fn)
    assert count_scalar_op(nodes, aes.Add) == 3
    assert count_scalar_op(nodes, aes.Pow) == 4
    assert has_scalar_op(nodes, aes.Mul)

    # (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w)
    expr = 2**x / 2**y * (a**z / a**w)
    fn = function([x, y, z, w, a], expr, mode)
    utt.assert_allclose(fn(3, 5, 6, 4, 7), 2 ** (3 - 5) * 7 ** (6 - 4))
    nodes = elemwise_graph(fn)
    assert count_scalar_op(nodes, aes.Sub) == 2
    assert has_scalar_op(nodes, aes.Mul)

    # a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w)
    expr = a**x * a**y * exp(z) * exp(w)
    fn = function([x, y, z, w, a], expr, mode)
    utt.assert_allclose(fn(3, 4, 5, 6, 2), 2 ** (3 + 4) * np.exp(5 + 6))
    nodes = elemwise_graph(fn)
    assert count_scalar_op(nodes, aes.Add) == 2
    assert has_scalar_op(nodes, aes.Mul)
def test_local_expm1(): def test_local_expm1():
x = matrix("x") x = matrix("x")
u = scalar("u") u = scalar("u")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论