提交 a314476f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Generalize determinant from factorization rewrites

上级 16220296
...@@ -23,6 +23,7 @@ from pytensor.tensor.basic import ( ...@@ -23,6 +23,7 @@ from pytensor.tensor.basic import (
concatenate, concatenate,
diag, diag,
diagonal, diagonal,
ones,
) )
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
...@@ -46,9 +47,12 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -46,9 +47,12 @@ from pytensor.tensor.rewriting.basic import (
) )
from pytensor.tensor.rewriting.blockwise import blockwise_of from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
LU,
QR,
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
LUFactor,
Solve, Solve,
SolveBase, SolveBase,
SolveTriangular, SolveTriangular,
...@@ -65,6 +69,10 @@ logger = logging.getLogger(__name__) ...@@ -65,6 +69,10 @@ logger = logging.getLogger(__name__)
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
def matrix_diagonal_product(x):
return pt.prod(diagonal(x, axis1=-2, axis2=-1), axis=-1)
@register_canonicalize @register_canonicalize
@node_rewriter([BlockDiagonal]) @node_rewriter([BlockDiagonal])
def fuse_blockdiagonal(fgraph, node): def fuse_blockdiagonal(fgraph, node):
...@@ -303,22 +311,6 @@ def cholesky_ldotlt(fgraph, node): ...@@ -303,22 +311,6 @@ def cholesky_ldotlt(fgraph, node):
return [r] return [r]
@register_stabilize
@register_specialize
@node_rewriter([det])
def local_det_chol(fgraph, node):
"""
If we have det(X) and there is already an L=cholesky(X)
floating around, then we can use prod(diag(L)) to get the determinant.
"""
(x,) = node.inputs
for cl, xpos in fgraph.clients[x]:
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
L = cl.outputs[0]
return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)]
@register_stabilize @register_stabilize
@register_specialize @register_specialize
@node_rewriter([log]) @node_rewriter([log])
...@@ -480,6 +472,127 @@ def _find_diag_from_eye_mul(potential_mul_input): ...@@ -480,6 +472,127 @@ def _find_diag_from_eye_mul(potential_mul_input):
return eye_input, non_eye_inputs return eye_input, non_eye_inputs
@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([det])
def det_of_matrix_factorized_elsewhere(fgraph, node):
"""
If we have det(X) or abs(det(X)) and there is already a nice decomposition(X) floating around,
use it to compute it more cheaply
"""
[det] = node.outputs
[x] = node.inputs
sign_not_needed = all(
isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, (Abs, Sqr))
for client, _ in fgraph.clients[det]
)
new_det = None
for client, _ in fgraph.clients[x]:
core_op = client.op.core_op if isinstance(client.op, Blockwise) else client.op
match core_op:
case Cholesky():
L = client.outputs[0]
new_det = matrix_diagonal_product(L) ** 2
case LU():
U = client.outputs[-1]
new_det = matrix_diagonal_product(U)
case LUFactor():
LU_packed = client.outputs[0]
new_det = matrix_diagonal_product(LU_packed)
case _:
if not sign_not_needed:
continue
match core_op:
case SVD():
lmbda = (
client.outputs[1]
if core_op.compute_uv
else client.outputs[0]
)
new_det = prod(lmbda, axis=-1)
case QR():
R = client.outputs[-1]
# if mode == "economic", R may not be square and this rewrite could hide a shape error
# That's why it's tagged as `shape_unsafe`
new_det = matrix_diagonal_product(R)
if new_det is not None:
# found a match
break
else: # no-break (i.e., no-match)
return None
[det] = node.outputs
copy_stack_trace(det, new_det)
return [new_det]
@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter(tracks=[det])
def det_of_factorized_matrix(fgraph, node):
"""Introduce special forms for det(decomposition(X)).
Some cases are only known up to a sign change such as det(QR(X)),
and are only introduced if the determinant sign is discarded downstream (e.g., abs, sqr)
"""
[det] = node.outputs
[x] = node.inputs
sign_not_needed = all(
isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, (Abs, Sqr))
for client, _ in fgraph.clients[det]
)
x_node = x.owner
if x_node is None:
return None
x_op = x_node.op
core_op = x_op.core_op if isinstance(x_op, Blockwise) else x_op
new_det = None
match core_op:
case Cholesky():
new_det = matrix_diagonal_product(x)
case LU():
if x is x_node.outputs[-2]:
# x is L
new_det = ones(x.shape[:-2], dtype=det.dtype)
elif x is x_node.outputs[-1]:
# x is U
new_det = matrix_diagonal_product(x)
case SVD():
if not core_op.compute_uv or x is x_node.outputs[1]:
# x is lambda
new_det = prod(x, axis=-1)
elif sign_not_needed:
# x is either U or Vt and sign is discarded downstream
new_det = ones(x.shape[:-2], dtype=det.dtype)
case QR():
# if mode == "economic", Q/R may not be square and this rewrite could hide a shape error
# That's why it's tagged as `shape_unsafe`
if x is x_node.outputs[-1]:
# x is R
new_det = matrix_diagonal_product(x)
elif (
sign_not_needed
and core_op.mode in ("economic", "full")
and x is x_node.outputs[0]
):
# x is Q and sign is discarded downstream
new_det = ones(x.shape[:-2], dtype=det.dtype)
if new_det is None:
return None
copy_stack_trace(det, new_det)
return [new_det]
@register_canonicalize("shape_unsafe") @register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe") @register_stabilize("shape_unsafe")
@node_rewriter([det]) @node_rewriter([det])
......
...@@ -17,6 +17,7 @@ from pytensor.tensor._linalg.solve.tridiagonal import ( ...@@ -17,6 +17,7 @@ from pytensor.tensor._linalg.solve.tridiagonal import (
) )
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.linalg import solve from pytensor.tensor.linalg import solve
from pytensor.tensor.nlinalg import det
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
Cholesky, Cholesky,
CholeskySolve, CholeskySolve,
...@@ -283,3 +284,181 @@ def test_local_log_prod_to_sum_log_positive_tag(expected, pos_tag): ...@@ -283,3 +284,181 @@ def test_local_log_prod_to_sum_log_positive_tag(expected, pos_tag):
rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
assert_equal_computations([rewritten], [expected(x)]) assert_equal_computations([rewritten], [expected(x)])
@pytest.mark.parametrize(
"decomp_fn, expected_fn",
[
pytest.param(
lambda x: pt.linalg.cholesky(x),
lambda x: pt.sqr(pt.prod(pt.diag(pt.linalg.cholesky(x)), axis=0)),
id="cholesky",
),
pytest.param(
lambda x: pt.linalg.lu(x)[-1],
lambda x: pt.prod(pt.extract_diag(pt.linalg.lu(x)[-1]), axis=0),
id="lu",
),
pytest.param(
lambda x: pt.linalg.lu_factor(x)[0],
lambda x: pt.prod(pt.extract_diag(pt.linalg.lu_factor(x)[0]), axis=0),
id="lu_factor",
),
],
)
def test_det_of_matrix_factorized_elsewhere(decomp_fn, expected_fn):
x = pt.tensor("x", shape=(3, 3))
decomp_var = decomp_fn(x)
d = det(x)
decomp_var, d = rewrite_graph(
[decomp_var, d], include=["canonicalize", "stabilize", "specialize"]
)
assert_equal_computations([decomp_var], [decomp_fn(x)])
assert_equal_computations([d], [expected_fn(x)])
@pytest.mark.parametrize(
"decomp_fn, sign_op, expected_fn",
[
pytest.param(
lambda x: pt.linalg.svd(x, compute_uv=True)[0],
pt.abs,
lambda x: pt.prod(pt.linalg.svd(x, compute_uv=True)[1], axis=0),
id="svd_abs",
),
pytest.param(
lambda x: pt.linalg.svd(x, compute_uv=False),
pt.abs,
lambda x: pt.prod(pt.linalg.svd(x, compute_uv=False), axis=0),
id="svd_no_uv_abs",
),
pytest.param(
lambda x: pt.linalg.qr(x)[0],
pt.abs,
lambda x: pt.prod(
pt.diagonal(pt.linalg.qr(x)[1], axis1=-2, axis2=-1), axis=-1
),
id="qr_abs",
),
pytest.param(
lambda x: pt.linalg.svd(x, compute_uv=True)[0],
pt.sqr,
lambda x: pt.prod(pt.linalg.svd(x, compute_uv=True)[1], axis=0),
id="svd_sqr",
),
pytest.param(
lambda x: pt.linalg.svd(x, compute_uv=False),
pt.sqr,
lambda x: pt.prod(pt.linalg.svd(x, compute_uv=False), axis=0),
id="svd_no_uv_sqr",
),
pytest.param(
lambda x: pt.linalg.qr(x)[0],
pt.sqr,
lambda x: pt.prod(
pt.diagonal(pt.linalg.qr(x)[1], axis1=-2, axis2=-1), axis=-1
),
id="qr_sqr",
),
],
)
def test_det_of_matrix_factorized_elsewhere_abs(decomp_fn, sign_op, expected_fn):
x = pt.tensor("x", shape=(3, 3))
decomp_var = decomp_fn(x)
d = sign_op(det(x))
decomp_var, d = rewrite_graph(
[decomp_var, d], include=["canonicalize", "stabilize", "specialize"]
)
assert_equal_computations([decomp_var], [decomp_fn(x)])
assert_equal_computations([d], [sign_op(expected_fn(x))])
@pytest.mark.parametrize(
"original_fn, expected_fn",
[
pytest.param(
lambda x: det(pt.linalg.cholesky(x)),
lambda x: pt.prod(
pt.diagonal(pt.linalg.cholesky(x), axis1=-2, axis2=-1), axis=-1
),
id="det_cholesky",
),
pytest.param(
lambda x: det(pt.linalg.lu(x)[-1]),
lambda x: pt.prod(
pt.diagonal(pt.linalg.lu(x)[-1], axis1=-2, axis2=-1), axis=-1
),
id="det_lu_U",
),
pytest.param(
lambda x: det(pt.linalg.lu(x)[-2]),
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
id="det_lu_L",
),
],
)
def test_det_of_factorized_matrix(original_fn, expected_fn):
x = pt.tensor("x", shape=(3, 3))
out = original_fn(x)
expected = expected_fn(x)
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
assert_equal_computations([rewritten], [expected])
@pytest.mark.parametrize(
"original_fn, expected_fn",
[
pytest.param(
lambda x: pt.abs(det(pt.linalg.svd(x, compute_uv=True)[0])),
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
id="abs_det_svd_U",
),
pytest.param(
lambda x: pt.abs(det(pt.linalg.svd(x, compute_uv=True)[2])),
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
id="abs_det_svd_Vt",
),
pytest.param(
lambda x: pt.abs(det(pt.linalg.qr(x)[0])),
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
id="abs_det_qr_Q",
),
pytest.param(
lambda x: pt.sqr(det(pt.linalg.svd(x, compute_uv=True)[0])),
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
id="sqr_det_svd_U",
),
pytest.param(
lambda x: pt.sqr(det(pt.linalg.svd(x, compute_uv=True)[2])),
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
id="sqr_det_svd_Vt",
),
pytest.param(
lambda x: pt.sqr(det(pt.linalg.qr(x)[0])),
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
id="sqr_det_qr_Q",
),
pytest.param(
lambda x: det(pt.linalg.qr(x)[1]),
lambda x: pt.prod(
pt.diagonal(pt.linalg.qr(x)[1], axis1=-2, axis2=-1), axis=-1
),
id="det_qr_R",
),
pytest.param(
lambda x: det(pt.linalg.qr(x)[0]),
lambda x: det(pt.linalg.qr(x)[0]),
id="det_qr_Q_no_rewrite",
),
],
)
def test_det_of_factorized_matrix_special_cases(original_fn, expected_fn):
x = pt.tensor("x", shape=(3, 3))
out = original_fn(x)
expected = expected_fn(x)
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
assert_equal_computations([rewritten], [expected])
...@@ -309,14 +309,15 @@ def test_local_det_chol(): ...@@ -309,14 +309,15 @@ def test_local_det_chol():
det_X = pt.linalg.det(X) det_X = pt.linalg.det(X)
f = function([X], [L, det_X]) f = function([X], [L, det_X])
assert not any(isinstance(node, Det) for node in f.maker.fgraph.apply_nodes)
nodes = f.maker.fgraph.toposort()
assert not any(isinstance(node, Det) for node in nodes)
# This previously raised an error (issue #392) # This previously raised an error (issue #392)
f = function([X], [L, det_X, X]) f = function([X], [L, det_X, X])
nodes = f.maker.fgraph.toposort() assert not any(isinstance(node, Det) for node in f.maker.fgraph.apply_nodes)
assert not any(isinstance(node, Det) for node in nodes)
# Test graph that only has det_X
f = function([X], [det_X])
assert not any(isinstance(node, Det) for node in f.maker.fgraph.apply_nodes)
def test_psd_solve_with_chol(): def test_psd_solve_with_chol():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论