Commit 16220296 authored by Ricardo Vieira, committed by Ricardo Vieira

Generalize log(prod(x)) -> sum(log(x)) rewrite

Parent 060d85f3
...@@ -14,7 +14,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -14,7 +14,7 @@ from pytensor.graph.rewriting.basic import (
node_rewriter, node_rewriter,
) )
from pytensor.graph.rewriting.unify import OpPattern from pytensor.graph.rewriting.unify import OpPattern
from pytensor.scalar.basic import Abs, Log, Mul, Sign from pytensor.scalar.basic import Abs, Exp, Log, Mul, Sign, Sqr
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
AllocDiag, AllocDiag,
ExtractDiag, ExtractDiag,
...@@ -319,27 +319,41 @@ def local_det_chol(fgraph, node): ...@@ -319,27 +319,41 @@ def local_det_chol(fgraph, node):
return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)]
@register_stabilize
@register_specialize
@node_rewriter([log])
def local_log_prod_to_sum_log(fgraph, node):
    """Rewrite log(prod(x)) as sum(log(x)), when x is known to be positive.

    Positivity is established either structurally — ``x`` is the output of an
    ``Abs``, ``Sqr`` or ``Exp`` elemwise, all of which produce non-negative
    (or strictly positive, for ``Exp``) values — or via a user-set boolean
    ``positive`` tag on the variable being reduced.

    Also handles the special case log(abs(prod(x))) -> sum(log(abs(x))),
    which shows up in slogdet.
    """
    # node is log(p); inspect what produced p.
    [p] = node.inputs
    p_node = p.owner

    if p_node is None:
        return None

    p_op = p_node.op
    if isinstance(p_op, Prod):
        # log(prod(x)) case: only safe when every term of x is positive,
        # because the product of mixed-sign terms may be positive while
        # individual log(x_i) are undefined.
        x = p_node.inputs[0]

        # TODO: The product of diagonals of a Cholesky(A) are also strictly positive
        if (
            x.owner is not None
            and isinstance(x.owner.op, Elemwise)
            and isinstance(x.owner.op.scalar_op, Abs | Sqr | Exp)
        ) or getattr(x.tag, "positive", False):
            # Preserve the reduction axes of the original Prod.
            return [log(x).sum(axis=p_node.op.axis)]

        # TODO: have a reduction like prod and sum that simply
        # returns the sign of the prod multiplication.

    # Special case for log(abs(prod(x))) -> sum(log(abs(x))) that shows up in slogdet
    elif isinstance(p_op, Elemwise) and isinstance(p_op.scalar_op, Abs):
        [p] = p_node.inputs
        p_node = p.owner
        if p_node is not None and isinstance(p_node.op, Prod):
            [x] = p.owner.inputs
            return [log(abs(x)).sum(axis=p_node.op.axis)]
@register_specialize @register_specialize
@node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)]) @node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)])
......
...@@ -2,8 +2,10 @@ import numpy as np ...@@ -2,8 +2,10 @@ import numpy as np
import pytest import pytest
from pytensor import config, function, scan from pytensor import config, function, scan
from pytensor import tensor as pt
from pytensor.compile.mode import get_default_mode from pytensor.compile.mode import get_default_mode
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph import rewrite_graph
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.tensor._linalg.solve.rewriting import ( from pytensor.tensor._linalg.solve.rewriting import (
reuse_decomposition_multiple_solves, reuse_decomposition_multiple_solves,
...@@ -23,6 +25,7 @@ from pytensor.tensor.slinalg import ( ...@@ -23,6 +25,7 @@ from pytensor.tensor.slinalg import (
SolveTriangular, SolveTriangular,
) )
from pytensor.tensor.type import tensor from pytensor.tensor.type import tensor
from tests.unittest_tools import assert_equal_computations
class DecompSolveOpCounter: class DecompSolveOpCounter:
...@@ -213,3 +216,70 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): ...@@ -213,3 +216,70 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
resx1 = fn_opt(A_test, x0_test) resx1 = fn_opt(A_test, x0_test)
rtol = 1e-7 if config.floatX == "float64" else 1e-4 rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol) np.testing.assert_allclose(resx0, resx1, rtol=rtol)
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        pytest.param(
            lambda x: pt.log(pt.prod(pt.abs(x))),
            lambda x: pt.sum(pt.log(pt.abs(x))),
            id="log_prod_abs",
        ),
        pytest.param(
            lambda x: pt.log(pt.prod(pt.exp(x))), lambda x: pt.sum(x), id="log_prod_exp"
        ),
        pytest.param(
            lambda x: pt.log(pt.prod(x**2)),
            lambda x: pt.sum(pt.log(pt.sqr(x))),
            id="log_prod_sqr",
        ),
        pytest.param(
            lambda x: pt.log(pt.abs(pt.prod(x))),
            lambda x: pt.sum(pt.log(pt.abs(x))),
            id="log_abs_prod",
        ),
        pytest.param(
            lambda x: pt.log(pt.prod(pt.abs(x), axis=0)),
            lambda x: pt.sum(pt.log(pt.abs(x)), axis=0),
            id="log_prod_abs_axis0",
        ),
        pytest.param(
            lambda x: pt.log(pt.prod(pt.exp(x), axis=-1)),
            lambda x: pt.sum(x, axis=-1),
            id="log_prod_exp_axis-1",
        ),
    ],
)
def test_local_log_prod_to_sum_log(original_fn, expected_fn):
    """Each structurally-positive log(prod(...)) graph rewrites to its sum(log(...)) form."""
    x = pt.tensor("x", shape=(3, 4))
    graph = original_fn(x)
    # Apply only the rewrite databases where the rule is registered.
    result = rewrite_graph(graph, include=["stabilize", "specialize"])
    assert_equal_computations([result], [expected_fn(x)])
@pytest.mark.parametrize(
    "expected, pos_tag",
    [
        pytest.param(
            lambda x: pt.sum(pt.log(x)),
            True,
            id="local_log_prod_to_sum_log_positive_tag",
        ),
        pytest.param(
            lambda x: pt.log(pt.prod(x)),
            False,
            id="local_log_prod_to_sum_log_no_rewrite",
        ),
    ],
)
def test_local_log_prod_to_sum_log_positive_tag(expected, pos_tag):
    """The rewrite fires only when the reduced variable carries a ``positive`` tag."""
    x = pt.tensor("x", shape=(3, 4))
    if pos_tag:
        # Mark x as positive so the tag-based branch of the rewrite applies.
        x.tag.positive = True
    graph = pt.log(pt.prod(x))
    result = rewrite_graph(graph, include=["stabilize", "specialize"])
    assert_equal_computations([result], [expected(x)])
Markdown format
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment