Unverified 提交 e37497fd authored 作者: Tamás Tőkés's avatar Tamás Tőkés 提交者: GitHub

Rewrite products of exponents as exponent of sum (#186)

* Rewrite products of exponents as exponent of sum. Rewrite e^x*e^y to e^(x+y), e^x/e^y to e^(x-y). * Rewrite a^x * a^y to a^(x+y)
上级 5628ab15
......@@ -2,6 +2,7 @@ r"""Rewrites for the `Op`\s in :mod:`pytensor.tensor.math`."""
import itertools
import operator
from collections import defaultdict
from functools import partial, reduce
import numpy as np
......@@ -423,6 +424,100 @@ def local_sumsqr2dot(fgraph, node):
return [new_out]
@register_specialize
@node_rewriter([mul, true_div])
def local_mul_exp_to_exp_add(fgraph, node):
"""
This rewrite detects e^x * e^y and converts it to e^(x+y).
Similarly, e^x / e^y becomes e^(x-y).
"""
exps = [
n.owner.inputs[0]
for n in node.inputs
if n.owner
and hasattr(n.owner.op, "scalar_op")
and isinstance(n.owner.op.scalar_op, aes.Exp)
]
# Can only do any rewrite if there are at least two exp-s
if len(exps) >= 2:
# Mul -> add; TrueDiv -> sub
orig_op, new_op = mul, add
if isinstance(node.op.scalar_op, aes.TrueDiv):
orig_op, new_op = true_div, sub
new_out = exp(new_op(*exps))
if new_out.dtype != node.outputs[0].dtype:
new_out = cast(new_out, dtype=node.outputs[0].dtype)
# The original Mul may have more than two factors, some of which may not be exp nodes.
# If so, we keep multiplying them with the new exp(sum) node.
# E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
rest = [
n
for n in node.inputs
if not n.owner
or not hasattr(n.owner.op, "scalar_op")
or not isinstance(n.owner.op.scalar_op, aes.Exp)
]
if len(rest) > 0:
new_out = orig_op(new_out, *rest)
if new_out.dtype != node.outputs[0].dtype:
new_out = cast(new_out, dtype=node.outputs[0].dtype)
return [new_out]
@register_specialize
@node_rewriter([mul, true_div])
def local_mul_pow_to_pow_add(fgraph, node):
"""
This rewrite detects a^x * a^y and converts it to a^(x+y).
Similarly, a^x / a^y becomes a^(x-y).
"""
# search for pow-s and group them by their bases
pow_nodes = defaultdict(list)
rest = []
for n in node.inputs:
if (
n.owner
and hasattr(n.owner.op, "scalar_op")
and isinstance(n.owner.op.scalar_op, aes.Pow)
):
base_node = n.owner.inputs[0]
# exponent is at n.owner.inputs[1], but we need to store the full node
# in case this particular power node remains alone and can't be rewritten
pow_nodes[base_node].append(n)
else:
rest.append(n)
# Can only do any rewrite if there are at least two pow-s with the same base
can_rewrite = [k for k, v in pow_nodes.items() if len(v) >= 2]
if len(can_rewrite) >= 1:
# Mul -> add; TrueDiv -> sub
orig_op, new_op = mul, add
if isinstance(node.op.scalar_op, aes.TrueDiv):
orig_op, new_op = true_div, sub
pow_factors = []
# Rewrite pow-s having the same base for each different base
# E.g.: a^x * a^y --> a^(x+y)
for base in can_rewrite:
exponents = [n.owner.inputs[1] for n in pow_nodes[base]]
new_node = base ** new_op(*exponents)
if new_node.dtype != node.outputs[0].dtype:
new_node = cast(new_node, dtype=node.outputs[0].dtype)
pow_factors.append(new_node)
# Don't forget about those sole pow-s that couldn't be rewriten
sole_pows = [v[0] for k, v in pow_nodes.items() if k not in can_rewrite]
# Combine the rewritten pow-s and other, non-pow factors of the original Mul
# E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v
if len(pow_factors) > 1 or len(sole_pows) > 0 or len(rest) > 0:
new_out = orig_op(*pow_factors, *sole_pows, *rest)
if new_out.dtype != node.outputs[0].dtype:
new_out = cast(new_out, dtype=node.outputs[0].dtype)
else:
# if all factors of the original mul were pows-s with the same base,
# we can get rid of the mul completely.
new_out = pow_factors[0]
return [new_out]
@register_stabilize
@register_specialize
@register_canonicalize
......
......@@ -4014,6 +4014,161 @@ def test_local_sumsqr2dot():
)
def test_local_mul_exp_to_exp_add():
# Default and FAST_RUN modes put a Composite op into the final graph,
# whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
# we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
mode = get_default_mode().excluding("fusion").including("local_mul_exp_to_exp_add")
x = scalar("x")
y = scalar("y")
z = scalar("z")
w = scalar("w")
expx = exp(x)
expy = exp(y)
expz = exp(z)
expw = exp(w)
# e^x * e^y * e^z * e^w = e^(x+y+z+w)
op = expx * expy * expz * expw
f = function([x, y, z, w], op, mode)
pytensor.dprint(f)
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6))
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
# e^x * e^y * e^z / e^w = e^(x+y+z-w)
op = expx * expy * expz / expw
f = function([x, y, z, w], op, mode)
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 - 6))
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)
# e^x * e^y / e^z * e^w = e^(x+y-z+w)
op = expx * expy / expz * expw
f = function([x, y, z, w], op, mode)
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 - 5 + 6))
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)
# e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z)
op = expx / expy / expz
f = function([x, y, z], op, mode)
utt.assert_allclose(f(3, 4, 5), np.exp(3 - 4 - 5))
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)
# e^x * y * e^z * w = e^(x+z) * y * w
op = expx * y * expz * w
f = function([x, y, z, w], op, mode)
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 5) * 4 * 6)
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
# expect same for matrices as well
mx = matrix("mx")
my = matrix("my")
f = function([mx, my], exp(mx) * exp(my), mode, allow_input_downcast=True)
M1 = np.array([[1.0, 2.0], [3.0, 4.0]])
M2 = np.array([[5.0, 6.0], [7.0, 8.0]])
utt.assert_allclose(f(M1, M2), np.exp(M1 + M2))
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
# checking whether further rewrites can proceed after this one as one would expect
# e^x * e^(-x) = e^(x-x) = e^0 = 1
f = function([x], expx * exp(neg(x)), mode)
utt.assert_allclose(f(42), 1)
graph = f.maker.fgraph.toposort()
assert isinstance(graph[0].inputs[0], TensorConstant)
# e^x / e^x = e^(x-x) = e^0 = 1
f = function([x], expx / expx, mode)
utt.assert_allclose(f(42), 1)
graph = f.maker.fgraph.toposort()
assert isinstance(graph[0].inputs[0], TensorConstant)
def test_local_mul_pow_to_pow_add():
# Default and FAST_RUN modes put a Composite op into the final graph,
# whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
# we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
mode = (
get_default_mode()
.excluding("fusion")
.including("local_mul_exp_to_exp_add")
.including("local_mul_pow_to_pow_add")
)
x = scalar("x")
y = scalar("y")
z = scalar("z")
w = scalar("w")
v = scalar("v")
u = scalar("u")
t = scalar("t")
s = scalar("s")
a = scalar("a")
b = scalar("b")
c = scalar("c")
# 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w)
op = 2**x * 2**y * 2**z * 2**w
f = function([x, y, z, w], op, mode)
utt.assert_allclose(f(3, 4, 5, 6), 2 ** (3 + 4 + 5 + 6))
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
# 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s
op = 2**x * a**y * 2**z * b**w * c**v * a**u * s * b**t
f = function([x, y, z, w, v, u, t, s, a, b, c], op, mode)
utt.assert_allclose(
f(4, 5, 6, 7, 8, 9, 10, 11, 2.5, 3, 3.5),
2 ** (4 + 6) * 2.5 ** (5 + 9) * 3 ** (7 + 10) * 3.5**8 * 11,
)
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Add)]) == 3
assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Pow)]) == 4
assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
# (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w)
op = 2**x / 2**y * (a**z / a**w)
f = function([x, y, z, w, a], op, mode)
utt.assert_allclose(f(3, 5, 6, 4, 7), 2 ** (3 - 5) * 7 ** (6 - 4))
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Sub)]) == 2
assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
# a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w)
op = a**x * a**y * exp(z) * exp(w)
f = function([x, y, z, w, a], op, mode)
utt.assert_allclose(f(3, 4, 5, 6, 2), 2 ** (3 + 4) * np.exp(5 + 6))
graph = f.maker.fgraph.toposort()
assert all(isinstance(n.op, Elemwise) for n in graph)
assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Add)]) == 2
assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
def test_local_expm1():
x = matrix("x")
u = scalar("u")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论