提交 271c2463 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: Jesse Grabowski

Rewrite scalar solve to division

上级 cf860fa6
...@@ -47,8 +47,10 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -47,8 +47,10 @@ from pytensor.tensor.rewriting.basic import (
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve,
Solve, Solve,
SolveBase, SolveBase,
SolveTriangular,
_bilinear_solve_discrete_lyapunov, _bilinear_solve_discrete_lyapunov,
block_diag, block_diag,
cholesky, cholesky,
...@@ -908,6 +910,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): ...@@ -908,6 +910,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
return None return None
[input] = node.inputs [input] = node.inputs
# Check if input is a (1, 1) matrix
if all(input.type.broadcastable[-2:]):
return [pt.sqrt(input)]
# Check for use of pt.diag first # Check for use of pt.diag first
if ( if (
input.owner input.owner
...@@ -1020,3 +1027,42 @@ def slogdet_specialization(fgraph, node): ...@@ -1020,3 +1027,42 @@ def slogdet_specialization(fgraph, node):
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items() k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
} }
return replacements return replacements
@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
def scalar_solve_to_division(fgraph, node):
"""
Replace solve(a, b) with b / a if a is a (1, 1) matrix
"""
core_op = node.op.core_op
if not isinstance(core_op, SolveBase):
return None
a, b = node.inputs
old_out = node.outputs[0]
if not all(a.broadcastable[-2:]):
return None
# Special handling for different types of solve
match core_op:
case SolveTriangular():
# Corner case: if user asked for a triangular solve with a unit diagonal, a is taken to be 1
new_out = b / a if not core_op.unit_diagonal else b
case CholeskySolve():
new_out = b / a**2
case Solve():
new_out = b / a
case _:
raise NotImplementedError(
f"Unsupported core_op type: {type(core_op)} in scalar_solve_to_divison"
)
if core_op.b_ndim == 1:
new_out = new_out.squeeze(-1)
copy_stack_trace(old_out, new_out)
return [new_out]
...@@ -29,6 +29,7 @@ from pytensor.tensor.rewriting.linalg import inv_as_solve ...@@ -29,6 +29,7 @@ from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
BlockDiagonal, BlockDiagonal,
Cholesky, Cholesky,
CholeskySolve,
Solve, Solve,
SolveBase, SolveBase,
SolveTriangular, SolveTriangular,
...@@ -920,14 +921,6 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied(): ...@@ -920,14 +921,6 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
nodes = f_rewritten.maker.fgraph.apply_nodes nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes) assert any(isinstance(node.op, Cholesky) for node in nodes)
# Case 2 : eye is degenerate
x = pt.scalar("x")
y = pt.eye(1) * x
z_cholesky = pt.linalg.cholesky(y)
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes)
def test_slogdet_specialization(): def test_slogdet_specialization():
x, a = pt.dmatrix("x"), np.random.rand(20, 20) x, a = pt.dmatrix("x"), np.random.rand(20, 20)
...@@ -993,3 +986,37 @@ def test_slogdet_specialization(): ...@@ -993,3 +986,37 @@ def test_slogdet_specialization():
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN") f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, SLogDet) for node in nodes) assert not any(isinstance(node.op, SLogDet) for node in nodes)
@pytest.mark.parametrize(
"Op, fn",
[
(Solve, pt.linalg.solve),
(SolveTriangular, pt.linalg.solve_triangular),
(CholeskySolve, pt.linalg.cho_solve),
],
)
def test_scalar_solve_to_division_rewrite(Op, fn):
rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))
a = pt.dmatrix("a", shape=(1, 1))
b = pt.dvector("b")
if Op is CholeskySolve:
# cho_solve expects a tuple (c, lower) as the first input
c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
else:
c = fn(a, b, b_ndim=1)
f = function([a, b], c, mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Op) for node in nodes)
a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)
c_val = np.linalg.solve(a_val, b_val)
np.testing.assert_allclose(
f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论