提交 9c06de2b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add rewrite for solve with batched b

上级 8351f902
import logging
from typing import cast
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
......@@ -13,7 +13,14 @@ from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
)
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular
from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
cholesky,
solve,
solve_triangular,
)
logger = logging.getLogger(__name__)
......@@ -131,6 +138,52 @@ def generic_solve_to_solve_triangular(fgraph, node):
]
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
    """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T

    `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
    Only the last two dimensions of `b` and the output are swapped.

    Returns ``None`` (no rewrite) unless the node is a Blockwise over a
    SolveBase subclass with ``b_ndim == 1`` and a non-batched ``a``.
    """
    core_op = node.op.core_op

    # Only rewrite Blockwise-wrapped solve-family Ops (Solve, SolveTriangular, CholeskySolve, ...)
    if not isinstance(core_op, SolveBase):
        return None

    # The rewrite only targets the vector-b form of solve
    if node.op.core_op.b_ndim != 1:
        return None

    [a, b] = node.inputs

    # Check `b` is actually batched
    if b.type.ndim == 1:
        return None

    # Check `a` is a matrix (possibly with degenerate dims on the left)
    a_bcast_batch_dims = a.type.broadcastable[:-2]
    if not all(a_bcast_batch_dims):
        # `a` has genuine (non-broadcastable) batch dims: rewrite does not apply
        return None
    # We squeeze degenerate dims, any that are still needed will be introduced by the new_solve
    elif len(a_bcast_batch_dims):
        a = a.squeeze(axis=tuple(range(len(a_bcast_batch_dims))))

    # Recreate solve Op with b_ndim=2
    # NOTE(review): relies on `_props_dict()` returning all constructor kwargs
    # of the concrete SolveBase subclass — confirm for custom subclasses.
    props = core_op._props_dict()
    props["b_ndim"] = 2
    new_core_op = type(core_op)(**props)
    matrix_b_solve = Blockwise(new_core_op)

    # Apply the rewrite
    # `_T` is a module-level helper (defined outside this view); presumably it
    # swaps the last two axes so batched vectors become matrix columns — confirm.
    new_solve = _T(matrix_b_solve(a, _T(b)))

    old_solve = node.outputs[0]
    copy_stack_trace(old_solve, new_solve)

    return [new_solve]
@register_canonicalize
@register_stabilize
@register_specialize
......
from functools import partial
import numpy as np
import numpy.linalg
import pytest
import scipy.linalg
from numpy.testing import assert_allclose
......@@ -17,7 +16,16 @@ from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.slinalg import (
Cholesky,
Solve,
SolveBase,
SolveTriangular,
cho_solve,
cholesky,
solve,
solve_triangular,
)
from pytensor.tensor.type import dmatrix, matrix, tensor, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op
......@@ -231,3 +239,70 @@ def test_local_det_chol():
f = function([X], [L, det_X, X])
nodes = f.maker.fgraph.toposort()
assert not any(isinstance(node, Det) for node in nodes)
class TestBatchedVectorBSolveToMatrixBSolve:
    """Tests for the `batched_vector_b_solve_to_matrix_b_solve` rewrite."""

    rewrite_name = "batched_vector_b_solve_to_matrix_b_solve"

    @staticmethod
    def any_vector_b_solve(fn):
        """Return True if the compiled graph still contains a Solve with b_ndim=1."""
        for apply_node in fn.maker.fgraph.apply_nodes:
            op = apply_node.op
            if (
                isinstance(op, Blockwise)
                and isinstance(op.core_op, SolveBase)
                and op.core_op.b_ndim == 1
            ):
                return True
        return False

    @pytest.mark.parametrize("solve_op", (solve, solve_triangular, cho_solve))
    def test_valid_cases(self, solve_op):
        rng = np.random.default_rng(sum(map(ord, solve_op.__name__)))

        a = tensor(shape=(None, None))
        b = tensor(shape=(None, None, None))

        if solve_op is cho_solve:
            # cho_solve expects a tuple (a, lower) as the first input
            out = solve_op((a, True), b, b_ndim=1)
        else:
            out = solve_op(a, b, b_ndim=1)

        # Reference: compile with the rewrite disabled; a vector-b solve must remain.
        without_rewrite = get_default_mode().excluding(self.rewrite_name)
        ref_fn = pytensor.function([a, b], out, mode=without_rewrite)
        assert self.any_vector_b_solve(ref_fn)

        # With the rewrite enabled no vector-b solve should survive.
        with_rewrite = get_default_mode().including(self.rewrite_name)
        opt_fn = pytensor.function([a, b], out, mode=with_rewrite)
        assert not self.any_vector_b_solve(opt_fn)

        # Numerical equivalence between the two compiled graphs.
        a_val = rng.normal(size=(3, 3)).astype(config.floatX)
        b_val = rng.normal(size=(7, 5, 3)).astype(config.floatX)
        np.testing.assert_allclose(
            opt_fn(a_val, b_val),
            ref_fn(a_val, b_val),
            rtol=1e-7 if config.floatX == "float64" else 1e-5,
        )

    def test_invalid_batched_a(self):
        rng = np.random.default_rng(sum(map(ord, self.rewrite_name)))

        # Rewrite is not applicable if a has batched dims
        a = tensor(shape=(None, None, None))
        b = tensor(shape=(None, None, None))
        out = solve(a, b, b_ndim=1)

        with_rewrite = get_default_mode().including(self.rewrite_name)
        opt_fn = pytensor.function([a, b], out, mode=with_rewrite)
        # The vector-b solve must still be present: the rewrite should bail out.
        assert self.any_vector_b_solve(opt_fn)

        # Compare against a NumPy reference vectorized over the batch dims.
        ref_fn = np.vectorize(np.linalg.solve, signature="(m,m),(m)->(m)")
        a_val = rng.normal(size=(5, 3, 3)).astype(config.floatX)
        b_val = rng.normal(size=(7, 5, 3)).astype(config.floatX)
        np.testing.assert_allclose(
            opt_fn(a_val, b_val),
            ref_fn(a_val, b_val),
            rtol=1e-7 if config.floatX == "float64" else 1e-5,
        )
......@@ -257,6 +257,8 @@ class BlockwiseOpTester:
np.testing.assert_allclose(
pt_func(*vec_inputs_testvals),
np_func(*vec_inputs_testvals),
rtol=1e-7 if config.floatX == "float64" else 1e-5,
atol=1e-7 if config.floatX == "float64" else 1e-5,
)
def test_grad(self):
......@@ -288,6 +290,7 @@ class BlockwiseOpTester:
np.testing.assert_allclose(
pt_out,
np_out,
rtol=1e-7 if config.floatX == "float64" else 1e-5,
atol=1e-6 if config.floatX == "float64" else 1e-5,
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论