Commit 815671d5 authored by Ricardo Vieira, committed by Ricardo Vieira

Fix shape errors in `scalar_solve_to_division`

Parent 80acf202
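For context: `scalar_solve_to_division` rewrites `solve(a, b)` into a plain division when `a` is statically known to be a 1x1 matrix. A minimal NumPy sketch of the identity the rewrite relies on (illustrative, not part of the commit):

import numpy as np

a = np.array([[2.0]])  # 1x1 system matrix
b = np.array([3.0])    # length-1 right-hand side

# Solving a 1x1 system reduces to dividing by the single matrix entry.
np.testing.assert_allclose(np.linalg.solve(a, b), b / a[0, 0])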
@@ -1046,11 +1046,15 @@ def scalar_solve_to_division(fgraph, node):
     if not all(a.broadcastable[-2:]):
         return None
 
+    if core_op.b_ndim == 1:
+        # Convert b to a column matrix
+        b = b[..., None]
+
+    # Special handling for different types of solve
     match core_op:
         case SolveTriangular():
             # Corner case: if user asked for a triangular solve with a unit diagonal, a is taken to be 1
-            new_out = b / a if not core_op.unit_diagonal else b
+            new_out = b / a if not core_op.unit_diagonal else pt.second(a, b)
         case CholeskySolve():
             new_out = b / a**2
         case Solve():
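Note on the hunk above: the fix swaps the bare `b` for `pt.second(a, b)` in the unit-diagonal corner case. `pt.second` returns its second argument broadcast against the first, so the result keeps any batch dimensions contributed by `a` rather than silently dropping them. A small sketch under that assumption (shapes are illustrative):

import pytensor.tensor as pt

a = pt.tensor("a", shape=(5, 1, 1))  # batched 1x1 matrices
b = pt.tensor("b", shape=(1, 1))     # unbatched column matrix

# Returning `b` directly would have static shape (1, 1); broadcasting it
# against `a` restores the batch dimension the solve output should have.
out = pt.second(a, b)
print(out.type.shape)  # (5, 1, 1)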
@@ -1061,6 +1065,7 @@ def scalar_solve_to_division(fgraph, node):
             )
 
     if core_op.b_ndim == 1:
+        # Squeeze away the column dimension added earlier
         new_out = new_out.squeeze(-1)
 
     copy_stack_trace(old_out, new_out)
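The column-matrix handling in the two hunks above is what fixes the reported shape errors: with `b_ndim == 1` and a batched `a`, dividing the raw vector by the `(..., 1, 1)` matrix broadcasts along the wrong axes. A NumPy sketch of the failure mode and the fix (shapes illustrative, not from the commit):

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(5, 1, 1))  # batch of 1x1 matrices
b = rng.normal(size=(5, 1))     # batch of length-1 vectors

wrong = b / a                           # broadcasts to (5, 5, 1)
right = (b[..., None] / a).squeeze(-1)  # (5, 1), matches solve(a, b)

assert wrong.shape == (5, 5, 1)
assert right.shape == (5, 1)
np.testing.assert_allclose(
    right, np.vectorize(np.linalg.solve, signature="(n,m),(m)->(n)")(a, b)
)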
@@ -10,6 +10,7 @@ from pytensor import function
 from pytensor import tensor as pt
 from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
+from pytensor.graph import ancestors
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor import swapaxes
 from pytensor.tensor.blockwise import Blockwise
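The new `ancestors` import supports the `solve_op_in_graph` helper added below: `pytensor.graph.ancestors` iterates over every variable the given outputs depend on, letting the test assert that a solve Op is present before compilation and absent afterwards. A tiny sketch (the graph is illustrative):

import pytensor.tensor as pt
from pytensor.graph import ancestors

x = pt.vector("x")
y = pt.exp(x).sum()

# ancestors() yields every variable y depends on, including the input x.
assert x in set(ancestors([y]))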
@@ -989,34 +990,73 @@ def test_slogdet_specialization():
 @pytest.mark.parametrize(
-    "Op, fn",
+    "a_batch_shape", [(), (5,)], ids=lambda x: f"a_batch_shape={x}"
+)
+@pytest.mark.parametrize(
+    "b_batch_shape", [(), (5,)], ids=lambda x: f"b_batch_shape={x}"
+)
+@pytest.mark.parametrize("b_ndim", (1, 2), ids=lambda x: f"b_ndim={x}")
+@pytest.mark.parametrize(
+    "op, fn, extra_kwargs",
     [
-        (Solve, pt.linalg.solve),
-        (SolveTriangular, pt.linalg.solve_triangular),
-        (CholeskySolve, pt.linalg.cho_solve),
+        (Solve, pt.linalg.solve, {}),
+        (SolveTriangular, pt.linalg.solve_triangular, {}),
+        (SolveTriangular, pt.linalg.solve_triangular, {"unit_diagonal": True}),
+        (CholeskySolve, pt.linalg.cho_solve, {}),
     ],
 )
-def test_scalar_solve_to_division_rewrite(Op, fn):
-    rng = np.random.default_rng(sum(map(ord, "scalar_solve_to_division_rewrite")))
+def test_scalar_solve_to_division_rewrite(
+    op, fn, extra_kwargs, b_ndim, a_batch_shape, b_batch_shape
+):
+    def solve_op_in_graph(graph):
+        return any(
+            isinstance(var.owner.op, SolveBase)
+            or (
+                isinstance(var.owner.op, Blockwise)
+                and isinstance(var.owner.op.core_op, SolveBase)
+            )
+            for var in ancestors(graph)
+            if var.owner
+        )
+
+    rng = np.random.default_rng(
+        [
+            sum(map(ord, "scalar_solve_to_division_rewrite")),
+            b_ndim,
+            *a_batch_shape,
+            1,
+            *b_batch_shape,
+        ]
+    )
 
-    a = pt.dmatrix("a", shape=(1, 1))
-    b = pt.dvector("b")
+    a = pt.tensor("a", shape=(*a_batch_shape, 1, 1), dtype="float64")
+    b = pt.tensor("b", shape=(*b_batch_shape, *([None] * b_ndim)), dtype="float64")
 
-    if Op is CholeskySolve:
+    if op is CholeskySolve:
         # cho_solve expects a tuple (c, lower) as the first input
-        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=1)
+        c = fn((pt.linalg.cholesky(a), True), b, b_ndim=b_ndim, **extra_kwargs)
     else:
-        c = fn(a, b, b_ndim=1)
+        c = fn(a, b, b_ndim=b_ndim, **extra_kwargs)
 
+    assert solve_op_in_graph([c])
     f = function([a, b], c, mode="FAST_RUN")
-    nodes = f.maker.fgraph.apply_nodes
+    assert not solve_op_in_graph(f.maker.fgraph.outputs)
 
-    assert not any(isinstance(node.op, Op) for node in nodes)
+    a_val = rng.normal(size=(*a_batch_shape, 1, 1)).astype(pytensor.config.floatX)
+    b_core_shape = (1, 5) if b_ndim == 2 else (1,)
+    b_val = rng.normal(size=(*b_batch_shape, *b_core_shape)).astype(
+        pytensor.config.floatX
+    )
 
-    a_val = rng.normal(size=(1, 1)).astype(pytensor.config.floatX)
-    b_val = rng.normal(size=(1,)).astype(pytensor.config.floatX)
+    if op is CholeskySolve:
+        # Avoid sign ambiguity in solve
+        a_val = a_val**2
+    if extra_kwargs.get("unit_diagonal", False):
+        a_val = np.ones_like(a_val)
 
-    c_val = np.linalg.solve(a_val, b_val)
+    signature = "(n,m),(m)->(n)" if b_ndim == 1 else "(n,m),(m,k)->(n,k)"
+    c_val = np.vectorize(np.linalg.solve, signature=signature)(a_val, b_val)
     np.testing.assert_allclose(
         f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
     )
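A note on the reference computation: `np.vectorize` with a gufunc `signature` loops `np.linalg.solve` over all leading batch dimensions, keeping the vector (`b_ndim == 1`) and matrix (`b_ndim == 2`) right-hand-side cases unambiguous. A standalone sketch mirroring the `b_ndim == 2` case (values illustrative):

import numpy as np

rng = np.random.default_rng(0)
a_val = rng.normal(size=(5, 1, 1))  # batch (5,), core (1, 1)
b_val = rng.normal(size=(5, 1, 5))  # batch (5,), core (1, 5)

# Core dims are (n,m),(m,k)->(n,k); dimensions to their left are batched.
batched_solve = np.vectorize(np.linalg.solve, signature="(n,m),(m,k)->(n,k)")
c_val = batched_solve(a_val, b_val)
assert c_val.shape == (5, 1, 5)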