提交 9c06de2b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add rewrite for solve with batched b

上级 8351f902
import logging import logging
from typing import cast from typing import cast
from pytensor.graph.rewriting.basic import node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
...@@ -13,7 +13,14 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -13,7 +13,14 @@ from pytensor.tensor.rewriting.basic import (
register_specialize, register_specialize,
register_stabilize, register_stabilize,
) )
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
cholesky,
solve,
solve_triangular,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -131,6 +138,52 @@ def generic_solve_to_solve_triangular(fgraph, node): ...@@ -131,6 +138,52 @@ def generic_solve_to_solve_triangular(fgraph, node):
] ]
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
    """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T

    `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
    Only the last two dimensions of `b` and the output are swapped.
    """
    core_op = node.op.core_op

    # Only act on Blockwise-wrapped solves that treat `b` as a vector
    if not (isinstance(core_op, SolveBase) and core_op.b_ndim == 1):
        return None

    a, b = node.inputs

    # If `b` carries no batch dimension there is nothing to fuse
    if b.type.ndim == 1:
        return None

    # `a` may only have degenerate (broadcastable) batch dims on the left
    batch_bcast = a.type.broadcastable[:-2]
    if not all(batch_bcast):
        return None
    if batch_bcast:
        # Drop the degenerate dims; any that are still needed will be
        # reintroduced by the new Blockwise solve
        a = a.squeeze(axis=tuple(range(len(batch_bcast))))

    # Build an equivalent core Op that solves for a matrix rhs (b_ndim=2)
    new_props = core_op._props_dict()
    new_props["b_ndim"] = 2
    matrix_b_solve = Blockwise(type(core_op)(**new_props))

    # Solve against the transposed rhs, then transpose the result back
    new_solve = _T(matrix_b_solve(a, _T(b)))
    copy_stack_trace(node.outputs[0], new_solve)
    return [new_solve]
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
......
from functools import partial from functools import partial
import numpy as np import numpy as np
import numpy.linalg
import pytest import pytest
import scipy.linalg import scipy.linalg
from numpy.testing import assert_allclose from numpy.testing import assert_allclose
...@@ -17,7 +16,16 @@ from pytensor.tensor.elemwise import DimShuffle ...@@ -17,7 +16,16 @@ from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
SolveTriangular,
cho_solve,
cholesky,
solve,
solve_triangular,
)
from pytensor.tensor.type import dmatrix, matrix, tensor, vector from pytensor.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.test_rop import break_op from tests.test_rop import break_op
...@@ -231,3 +239,70 @@ def test_local_det_chol(): ...@@ -231,3 +239,70 @@ def test_local_det_chol():
f = function([X], [L, det_X, X]) f = function([X], [L, det_X, X])
nodes = f.maker.fgraph.toposort() nodes = f.maker.fgraph.toposort()
assert not any(isinstance(node, Det) for node in nodes) assert not any(isinstance(node, Det) for node in nodes)
class TestBatchedVectorBSolveToMatrixBSolve:
    """Tests for the rewrite replacing batched vector-b solves with one matrix-b solve."""

    rewrite_name = "batched_vector_b_solve_to_matrix_b_solve"

    @staticmethod
    def any_vector_b_solve(fn):
        # True if the compiled graph still contains a Blockwise SolveBase with b_ndim=1
        for node in fn.maker.fgraph.apply_nodes:
            op = node.op
            if (
                isinstance(op, Blockwise)
                and isinstance(op.core_op, SolveBase)
                and op.core_op.b_ndim == 1
            ):
                return True
        return False

    @pytest.mark.parametrize("solve_op", (solve, solve_triangular, cho_solve))
    def test_valid_cases(self, solve_op):
        rng = np.random.default_rng(sum(map(ord, solve_op.__name__)))

        a = tensor(shape=(None, None))
        b = tensor(shape=(None, None, None))

        # cho_solve takes a tuple (a, lower) as its first argument
        if solve_op is cho_solve:
            out = solve_op((a, True), b, b_ndim=1)
        else:
            out = solve_op(a, b, b_ndim=1)

        ref_fn = pytensor.function(
            [a, b], out, mode=get_default_mode().excluding(self.rewrite_name)
        )
        assert self.any_vector_b_solve(ref_fn)

        opt_fn = pytensor.function(
            [a, b], out, mode=get_default_mode().including(self.rewrite_name)
        )
        assert not self.any_vector_b_solve(opt_fn)

        test_a = rng.normal(size=(3, 3)).astype(config.floatX)
        test_b = rng.normal(size=(7, 5, 3)).astype(config.floatX)
        np.testing.assert_allclose(
            opt_fn(test_a, test_b),
            ref_fn(test_a, test_b),
            rtol=1e-7 if config.floatX == "float64" else 1e-5,
        )

    def test_invalid_batched_a(self):
        rng = np.random.default_rng(sum(map(ord, self.rewrite_name)))

        # The rewrite must not trigger when `a` itself is batched
        a = tensor(shape=(None, None, None))
        b = tensor(shape=(None, None, None))
        out = solve(a, b, b_ndim=1)

        opt_fn = pytensor.function(
            [a, b], out, mode=get_default_mode().including(self.rewrite_name)
        )
        assert self.any_vector_b_solve(opt_fn)

        ref_fn = np.vectorize(np.linalg.solve, signature="(m,m),(m)->(m)")
        test_a = rng.normal(size=(5, 3, 3)).astype(config.floatX)
        test_b = rng.normal(size=(7, 5, 3)).astype(config.floatX)
        np.testing.assert_allclose(
            opt_fn(test_a, test_b),
            ref_fn(test_a, test_b),
            rtol=1e-7 if config.floatX == "float64" else 1e-5,
        )
...@@ -257,6 +257,8 @@ class BlockwiseOpTester: ...@@ -257,6 +257,8 @@ class BlockwiseOpTester:
np.testing.assert_allclose( np.testing.assert_allclose(
pt_func(*vec_inputs_testvals), pt_func(*vec_inputs_testvals),
np_func(*vec_inputs_testvals), np_func(*vec_inputs_testvals),
rtol=1e-7 if config.floatX == "float64" else 1e-5,
atol=1e-7 if config.floatX == "float64" else 1e-5,
) )
def test_grad(self): def test_grad(self):
...@@ -288,6 +290,7 @@ class BlockwiseOpTester: ...@@ -288,6 +290,7 @@ class BlockwiseOpTester:
np.testing.assert_allclose( np.testing.assert_allclose(
pt_out, pt_out,
np_out, np_out,
rtol=1e-7 if config.floatX == "float64" else 1e-5,
atol=1e-6 if config.floatX == "float64" else 1e-5, atol=1e-6 if config.floatX == "float64" else 1e-5,
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论