Commit 16220296 authored by Ricardo Vieira, committed by Ricardo Vieira

Generalize log(prod(x)) -> sum(log(x)) rewrite

Parent 060d85f3
......@@ -14,7 +14,7 @@ from pytensor.graph.rewriting.basic import (
node_rewriter,
)
from pytensor.graph.rewriting.unify import OpPattern
from pytensor.scalar.basic import Abs, Log, Mul, Sign
from pytensor.scalar.basic import Abs, Exp, Log, Mul, Sign, Sqr
from pytensor.tensor.basic import (
AllocDiag,
ExtractDiag,
......@@ -319,27 +319,41 @@ def local_det_chol(fgraph, node):
return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)]
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([log])
def local_log_prod_to_sum_log(fgraph, node):
    """Rewrite ``log(prod(x))`` as ``sum(log(x))`` when ``x`` is known to be positive.

    Positivity is established either structurally (``x`` is the output of an
    ``Abs``, ``Sqr`` or ``Exp`` elemwise, which are all non-negative) or via a
    user-set boolean ``positive`` tag on the variable.

    Also handles the ``log(abs(prod(x)))`` pattern (rewritten to
    ``sum(log(abs(x)))``), which shows up in ``slogdet``.

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph being rewritten (unused, required by the rewriter API).
    node : Apply
        A ``log`` node.

    Returns
    -------
    list of Variable or None
        The replacement output, or ``None`` when the pattern does not apply.
    """
    [p] = node.inputs
    p_node = p.owner
    if p_node is None:
        return None
    p_op = p_node.op
    if isinstance(p_op, Prod):
        # log(prod(x)) -> sum(log(x)), only valid when x cannot be negative.
        x = p_node.inputs[0]
        # TODO: The product of diagonals of a Cholesky(A) are also strictly positive
        if (
            x.owner is not None
            and isinstance(x.owner.op, Elemwise)
            and isinstance(x.owner.op.scalar_op, Abs | Sqr | Exp)
        ) or getattr(x.tag, "positive", False):
            return [log(x).sum(axis=p_node.op.axis)]
        # TODO: have a reduction like prod and sum that simply
        # returns the sign of the prod multiplication.
    # Special case for log(abs(prod(x))) -> sum(log(abs(x))) that shows up in slogdet
    elif isinstance(p_op, Elemwise) and isinstance(p_op.scalar_op, Abs):
        [p] = p_node.inputs
        p_node = p.owner
        if p_node is not None and isinstance(p_node.op, Prod):
            # Use the rebound alias for consistency: p_node is p.owner here.
            [x] = p_node.inputs
            return [log(abs(x)).sum(axis=p_node.op.axis)]
@register_specialize
@node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)])
......
......@@ -2,8 +2,10 @@ import numpy as np
import pytest
from pytensor import config, function, scan
from pytensor import tensor as pt
from pytensor.compile.mode import get_default_mode
from pytensor.gradient import grad
from pytensor.graph import rewrite_graph
from pytensor.scan.op import Scan
from pytensor.tensor._linalg.solve.rewriting import (
reuse_decomposition_multiple_solves,
......@@ -23,6 +25,7 @@ from pytensor.tensor.slinalg import (
SolveTriangular,
)
from pytensor.tensor.type import tensor
from tests.unittest_tools import assert_equal_computations
class DecompSolveOpCounter:
......@@ -213,3 +216,70 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
resx1 = fn_opt(A_test, x0_test)
rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol)
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        pytest.param(
            lambda x: pt.log(pt.prod(pt.abs(x))),
            lambda x: pt.sum(pt.log(pt.abs(x))),
            id="log_prod_abs",
        ),
        pytest.param(
            lambda x: pt.log(pt.prod(pt.exp(x))), lambda x: pt.sum(x), id="log_prod_exp"
        ),
        pytest.param(
            lambda x: pt.log(pt.prod(x**2)),
            lambda x: pt.sum(pt.log(pt.sqr(x))),
            id="log_prod_sqr",
        ),
        pytest.param(
            lambda x: pt.log(pt.abs(pt.prod(x))),
            lambda x: pt.sum(pt.log(pt.abs(x))),
            id="log_abs_prod",
        ),
        pytest.param(
            lambda x: pt.log(pt.prod(pt.abs(x), axis=0)),
            lambda x: pt.sum(pt.log(pt.abs(x)), axis=0),
            id="log_prod_abs_axis0",
        ),
        pytest.param(
            lambda x: pt.log(pt.prod(pt.exp(x), axis=-1)),
            lambda x: pt.sum(x, axis=-1),
            id="log_prod_exp_axis-1",
        ),
    ],
)
def test_local_log_prod_to_sum_log(original_fn, expected_fn):
    """Each log(prod(...)) pattern with a structurally positive argument
    must be rewritten to the equivalent sum(log(...)) graph."""
    x = pt.tensor("x", shape=(3, 4))
    out = original_fn(x)
    expected = expected_fn(x)
    rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
    assert_equal_computations([rewritten], [expected])
@pytest.mark.parametrize(
    "expected, pos_tag",
    [
        pytest.param(
            lambda x: pt.sum(pt.log(x)),
            True,
            id="local_log_prod_to_sum_log_positive_tag",
        ),
        pytest.param(
            lambda x: pt.log(pt.prod(x)),
            False,
            id="local_log_prod_to_sum_log_no_rewrite",
        ),
    ],
)
def test_local_log_prod_to_sum_log_positive_tag(expected, pos_tag):
    """The rewrite must fire only when the input carries a truthy
    ``positive`` tag; otherwise the original graph is left untouched."""
    x = pt.tensor("x", shape=(3, 4))
    if pos_tag:
        x.tag.positive = True
    out = pt.log(pt.prod(x))
    rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
    assert_equal_computations([rewritten], [expected(x)])
Markdown is supported
0%
You are about to add 0 people to this discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment