提交 a3eed0b4 作者: Ricardo Vieira 提交者: Thomas Wiecki

Blockwise some linalg Ops by default

上级 7fb4e70a
...@@ -3764,7 +3764,7 @@ def stacklists(arg): ...@@ -3764,7 +3764,7 @@ def stacklists(arg):
return arg return arg
def swapaxes(y, axis1, axis2): def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
"Swap the axes of a tensor." "Swap the axes of a tensor."
y = as_tensor_variable(y) y = as_tensor_variable(y)
ndim = y.ndim ndim = y.ndim
......
...@@ -10,11 +10,13 @@ from pytensor.graph.op import Op ...@@ -10,11 +10,13 @@ from pytensor.graph.op import Op
from pytensor.tensor import basic as at from pytensor.tensor import basic as at
from pytensor.tensor import math as tm from pytensor.tensor import math as tm
from pytensor.tensor.basic import as_tensor_variable, extract_diag from pytensor.tensor.basic import as_tensor_variable, extract_diag
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
class MatrixPinv(Op): class MatrixPinv(Op):
__props__ = ("hermitian",) __props__ = ("hermitian",)
gufunc_signature = "(m,n)->(n,m)"
def __init__(self, hermitian): def __init__(self, hermitian):
self.hermitian = hermitian self.hermitian = hermitian
...@@ -75,7 +77,7 @@ def pinv(x, hermitian=False): ...@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
solve op. solve op.
""" """
return MatrixPinv(hermitian=hermitian)(x) return Blockwise(MatrixPinv(hermitian=hermitian))(x)
class MatrixInverse(Op): class MatrixInverse(Op):
...@@ -93,6 +95,8 @@ class MatrixInverse(Op): ...@@ -93,6 +95,8 @@ class MatrixInverse(Op):
""" """
__props__ = () __props__ = ()
gufunc_signature = "(m,m)->(m,m)"
gufunc_spec = ("numpy.linalg.inv", 1, 1)
def __init__(self): def __init__(self):
pass pass
...@@ -150,7 +154,7 @@ class MatrixInverse(Op): ...@@ -150,7 +154,7 @@ class MatrixInverse(Op):
return shapes return shapes
inv = matrix_inverse = MatrixInverse() inv = matrix_inverse = Blockwise(MatrixInverse())
def matrix_dot(*args): def matrix_dot(*args):
...@@ -181,6 +185,8 @@ class Det(Op): ...@@ -181,6 +185,8 @@ class Det(Op):
""" """
__props__ = () __props__ = ()
gufunc_signature = "(m,m)->()"
gufunc_spec = ("numpy.linalg.det", 1, 1)
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -209,7 +215,7 @@ class Det(Op): ...@@ -209,7 +215,7 @@ class Det(Op):
return "Det" return "Det"
det = Det() det = Blockwise(Det())
class SLogDet(Op): class SLogDet(Op):
...@@ -218,6 +224,8 @@ class SLogDet(Op): ...@@ -218,6 +224,8 @@ class SLogDet(Op):
""" """
__props__ = () __props__ = ()
gufunc_signature = "(m, m)->(),()"
gufunc_spec = ("numpy.linalg.slogdet", 1, 2)
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -242,7 +250,7 @@ class SLogDet(Op): ...@@ -242,7 +250,7 @@ class SLogDet(Op):
return "SLogDet" return "SLogDet"
slogdet = SLogDet() slogdet = Blockwise(SLogDet())
class Eig(Op): class Eig(Op):
...@@ -252,6 +260,8 @@ class Eig(Op): ...@@ -252,6 +260,8 @@ class Eig(Op):
""" """
__props__: Tuple[str, ...] = () __props__: Tuple[str, ...] = ()
gufunc_signature = "(m,m)->(m),(m,m)"
gufunc_spec = ("numpy.linalg.eig", 1, 2)
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -270,7 +280,7 @@ class Eig(Op): ...@@ -270,7 +280,7 @@ class Eig(Op):
return [(n,), (n, n)] return [(n,), (n, n)]
eig = Eig() eig = Blockwise(Eig())
class Eigh(Eig): class Eigh(Eig):
......
import logging import logging
import typing import typing
import warnings import warnings
from typing import TYPE_CHECKING, Literal, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np import numpy as np
import scipy.linalg import scipy.linalg
...@@ -13,6 +13,7 @@ from pytensor.graph.op import Op ...@@ -13,6 +13,7 @@ from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at from pytensor.tensor import basic as at
from pytensor.tensor import math as atm from pytensor.tensor import math as atm
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import matrix_dot from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.type import matrix, tensor, vector
...@@ -48,6 +49,7 @@ class Cholesky(Op): ...@@ -48,6 +49,7 @@ class Cholesky(Op):
# TODO: LAPACK wrapper with in-place behavior, for solve also # TODO: LAPACK wrapper with in-place behavior, for solve also
__props__ = ("lower", "destructive", "on_error") __props__ = ("lower", "destructive", "on_error")
gufunc_signature = "(m,m)->(m,m)"
def __init__(self, *, lower=True, on_error="raise"): def __init__(self, *, lower=True, on_error="raise"):
self.lower = lower self.lower = lower
...@@ -109,7 +111,7 @@ class Cholesky(Op): ...@@ -109,7 +111,7 @@ class Cholesky(Op):
def conjugate_solve_triangular(outer, inner): def conjugate_solve_triangular(outer, inner):
"""Computes L^{-T} P L^{-1} for lower-triangular L.""" """Computes L^{-T} P L^{-1} for lower-triangular L."""
solve_upper = SolveTriangular(lower=False) solve_upper = SolveTriangular(lower=False, b_ndim=2)
return solve_upper(outer.T, solve_upper(outer.T, inner.T).T) return solve_upper(outer.T, solve_upper(outer.T, inner.T).T)
s = conjugate_solve_triangular( s = conjugate_solve_triangular(
...@@ -128,7 +130,7 @@ class Cholesky(Op): ...@@ -128,7 +130,7 @@ class Cholesky(Op):
def cholesky(x, lower=True, on_error="raise"): def cholesky(x, lower=True, on_error="raise"):
return Cholesky(lower=lower, on_error=on_error)(x) return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
class SolveBase(Op): class SolveBase(Op):
...@@ -137,6 +139,7 @@ class SolveBase(Op): ...@@ -137,6 +139,7 @@ class SolveBase(Op):
__props__ = ( __props__ = (
"lower", "lower",
"check_finite", "check_finite",
"b_ndim",
) )
def __init__( def __init__(
...@@ -144,9 +147,16 @@ class SolveBase(Op): ...@@ -144,9 +147,16 @@ class SolveBase(Op):
*, *,
lower=False, lower=False,
check_finite=True, check_finite=True,
b_ndim,
): ):
self.lower = lower self.lower = lower
self.check_finite = check_finite self.check_finite = check_finite
assert b_ndim in (1, 2)
self.b_ndim = b_ndim
if b_ndim == 1:
self.gufunc_signature = "(m,m),(m)->(m)"
else:
self.gufunc_signature = "(m,m),(m,n)->(m,n)"
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
pass pass
...@@ -157,8 +167,8 @@ class SolveBase(Op): ...@@ -157,8 +167,8 @@ class SolveBase(Op):
if A.ndim != 2: if A.ndim != 2:
raise ValueError(f"`A` must be a matrix; got {A.type} instead.") raise ValueError(f"`A` must be a matrix; got {A.type} instead.")
if b.ndim not in (1, 2): if b.ndim != self.b_ndim:
raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.") raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
# Infer dtype by solving the most simple case with 1x1 matrices # Infer dtype by solving the most simple case with 1x1 matrices
o_dtype = scipy.linalg.solve( o_dtype = scipy.linalg.solve(
...@@ -209,6 +219,16 @@ class SolveBase(Op): ...@@ -209,6 +219,16 @@ class SolveBase(Op):
return [A_bar, b_bar] return [A_bar, b_bar]
def _default_b_ndim(b, b_ndim):
if b_ndim is not None:
assert b_ndim in (1, 2)
return b_ndim
b = as_tensor_variable(b)
if b_ndim is None:
return min(b.ndim, 2) # By default assume the core case is a matrix
class CholeskySolve(SolveBase): class CholeskySolve(SolveBase):
def __init__(self, **kwargs): def __init__(self, **kwargs):
kwargs.setdefault("lower", True) kwargs.setdefault("lower", True)
...@@ -228,7 +248,7 @@ class CholeskySolve(SolveBase): ...@@ -228,7 +248,7 @@ class CholeskySolve(SolveBase):
raise NotImplementedError() raise NotImplementedError()
def cho_solve(c_and_lower, b, *, check_finite=True): def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: Optional[int] = None):
"""Solve the linear equations A x = b, given the Cholesky factorization of A. """Solve the linear equations A x = b, given the Cholesky factorization of A.
Parameters Parameters
...@@ -241,9 +261,15 @@ def cho_solve(c_and_lower, b, *, check_finite=True): ...@@ -241,9 +261,15 @@ def cho_solve(c_and_lower, b, *, check_finite=True):
Whether to check that the input matrices contain only finite numbers. Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs. (crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
""" """
A, lower = c_and_lower A, lower = c_and_lower
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b) b_ndim = _default_b_ndim(b, b_ndim)
return Blockwise(
CholeskySolve(lower=lower, check_finite=check_finite, b_ndim=b_ndim)
)(A, b)
class SolveTriangular(SolveBase): class SolveTriangular(SolveBase):
...@@ -254,6 +280,7 @@ class SolveTriangular(SolveBase): ...@@ -254,6 +280,7 @@ class SolveTriangular(SolveBase):
"unit_diagonal", "unit_diagonal",
"lower", "lower",
"check_finite", "check_finite",
"b_ndim",
) )
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs): def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
...@@ -291,6 +318,7 @@ def solve_triangular( ...@@ -291,6 +318,7 @@ def solve_triangular(
lower: bool = False, lower: bool = False,
unit_diagonal: bool = False, unit_diagonal: bool = False,
check_finite: bool = True, check_finite: bool = True,
b_ndim: Optional[int] = None,
) -> TensorVariable: ) -> TensorVariable:
"""Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix. """Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.
...@@ -314,12 +342,19 @@ def solve_triangular( ...@@ -314,12 +342,19 @@ def solve_triangular(
Whether to check that the input matrices contain only finite numbers. Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs. (crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
""" """
return SolveTriangular( b_ndim = _default_b_ndim(b, b_ndim)
lower=lower, return Blockwise(
trans=trans, SolveTriangular(
unit_diagonal=unit_diagonal, lower=lower,
check_finite=check_finite, trans=trans,
unit_diagonal=unit_diagonal,
check_finite=check_finite,
b_ndim=b_ndim,
)
)(a, b) )(a, b)
...@@ -332,6 +367,7 @@ class Solve(SolveBase): ...@@ -332,6 +367,7 @@ class Solve(SolveBase):
"assume_a", "assume_a",
"lower", "lower",
"check_finite", "check_finite",
"b_ndim",
) )
def __init__(self, *, assume_a="gen", **kwargs): def __init__(self, *, assume_a="gen", **kwargs):
...@@ -352,7 +388,15 @@ class Solve(SolveBase): ...@@ -352,7 +388,15 @@ class Solve(SolveBase):
) )
def solve(a, b, *, assume_a="gen", lower=False, check_finite=True): def solve(
a,
b,
*,
assume_a="gen",
lower=False,
check_finite=True,
b_ndim: Optional[int] = None,
):
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix. """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
If the data matrix is known to be a particular type then supplying the If the data matrix is known to be a particular type then supplying the
...@@ -375,9 +419,9 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True): ...@@ -375,9 +419,9 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
Parameters Parameters
---------- ----------
a : (N, N) array_like a : (..., N, N) array_like
Square input data Square input data
b : (N, NRHS) array_like b : (..., N, NRHS) array_like
Input data for the right hand side. Input data for the right hand side.
lower : bool, optional lower : bool, optional
If True, only the data contained in the lower triangle of `a`. Default If True, only the data contained in the lower triangle of `a`. Default
...@@ -388,11 +432,18 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True): ...@@ -388,11 +432,18 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
(crashes, non-termination) if the inputs do contain infinities or NaNs. (crashes, non-termination) if the inputs do contain infinities or NaNs.
assume_a : str, optional assume_a : str, optional
Valid entries are explained above. Valid entries are explained above.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
""" """
return Solve( b_ndim = _default_b_ndim(b, b_ndim)
lower=lower, return Blockwise(
check_finite=check_finite, Solve(
assume_a=assume_a, lower=lower,
check_finite=check_finite,
assume_a=assume_a,
b_ndim=b_ndim,
)
)(a, b) )(a, b)
......
...@@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc): ...@@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc):
], ],
) )
def test_Solve(A, x, lower, exc): def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower=lower)(A, x) g = slinalg.Solve(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list): if isinstance(g, list):
g_fg = FunctionGraph(outputs=g) g_fg = FunctionGraph(outputs=g)
...@@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc): ...@@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc):
], ],
) )
def test_SolveTriangular(A, x, lower, exc): def test_SolveTriangular(A, x, lower, exc):
g = slinalg.SolveTriangular(lower=lower)(A, x) g = slinalg.SolveTriangular(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list): if isinstance(g, list):
g_fg = FunctionGraph(outputs=g) g_fg = FunctionGraph(outputs=g)
......
...@@ -9,11 +9,12 @@ from pytensor import function ...@@ -9,11 +9,12 @@ from pytensor import function
from pytensor import tensor as at from pytensor import tensor as at
from pytensor.compile import get_default_mode from pytensor.compile import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.type import dmatrix, matrix, vector from pytensor.tensor.type import dmatrix, matrix, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.test_rop import break_op from tests.test_rop import break_op
...@@ -23,7 +24,7 @@ def test_rop_lop(): ...@@ -23,7 +24,7 @@ def test_rop_lop():
mx = matrix("mx") mx = matrix("mx")
mv = matrix("mv") mv = matrix("mv")
v = vector("v") v = vector("v")
y = matrix_inverse(mx).sum(axis=0) y = MatrixInverse()(mx).sum(axis=0)
yv = pytensor.gradient.Rop(y, mx, mv) yv = pytensor.gradient.Rop(y, mx, mv)
rop_f = function([mx, mv], yv) rop_f = function([mx, mv], yv)
...@@ -83,13 +84,11 @@ def test_transinv_to_invtrans(): ...@@ -83,13 +84,11 @@ def test_transinv_to_invtrans():
def test_generic_solve_to_solve_triangular(): def test_generic_solve_to_solve_triangular():
cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False)
A = matrix("A") A = matrix("A")
x = matrix("x") x = matrix("x")
L = cholesky_lower(A) L = cholesky(A, lower=True)
U = cholesky_upper(A) U = cholesky(A, lower=False)
b1 = solve(L, x) b1 = solve(L, x)
b2 = solve(U, x) b2 = solve(U, x)
f = pytensor.function([A, x], b1) f = pytensor.function([A, x], b1)
...@@ -130,15 +129,15 @@ def test_matrix_inverse_solve(): ...@@ -130,15 +129,15 @@ def test_matrix_inverse_solve():
b = dmatrix("b") b = dmatrix("b")
node = matrix_inverse(A).dot(b).owner node = matrix_inverse(A).dot(b).owner
[out] = inv_as_solve.transform(None, node) [out] = inv_as_solve.transform(None, node)
assert isinstance(out.owner.op, Solve) assert isinstance(out.owner.op, Blockwise) and isinstance(
out.owner.op.core_op, Solve
)
@pytest.mark.parametrize("tag", ("lower", "upper", None)) @pytest.mark.parametrize("tag", ("lower", "upper", None))
@pytest.mark.parametrize("cholesky_form", ("lower", "upper")) @pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
@pytest.mark.parametrize("product", ("lower", "upper", None)) @pytest.mark.parametrize("product", ("lower", "upper", None))
def test_cholesky_ldotlt(tag, cholesky_form, product): def test_cholesky_ldotlt(tag, cholesky_form, product):
cholesky = Cholesky(lower=(cholesky_form == "lower"))
transform_removes_chol = tag is not None and product == tag transform_removes_chol = tag is not None and product == tag
transform_transposes = transform_removes_chol and cholesky_form != tag transform_transposes = transform_removes_chol and cholesky_form != tag
...@@ -153,11 +152,9 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): ...@@ -153,11 +152,9 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
else: else:
M = A M = A
C = cholesky(M) C = cholesky(M, lower=(cholesky_form == "lower"))
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt")) f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
print(f.maker.fgraph.apply_nodes)
no_cholesky_in_graph = not any( no_cholesky_in_graph = not any(
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
) )
......
...@@ -24,6 +24,7 @@ def test_vectorize_blockwise(): ...@@ -24,6 +24,7 @@ def test_vectorize_blockwise():
assert isinstance(vect_node.op, Blockwise) and isinstance( assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, MatrixInverse vect_node.op.core_op, MatrixInverse
) )
assert vect_node.op.signature == ("(m,m)->(m,m)")
assert vect_node.inputs[0] is tns assert vect_node.inputs[0] is tns
# Useless blockwise # Useless blockwise
...@@ -253,6 +254,11 @@ class TestMatrixInverse(MatrixOpBlockwiseTester): ...@@ -253,6 +254,11 @@ class TestMatrixInverse(MatrixOpBlockwiseTester):
signature = "(m, m) -> (m, m)" signature = "(m, m) -> (m, m)"
class TestSolve(BlockwiseOpTester): class TestSolveVector(BlockwiseOpTester):
core_op = Solve(lower=True) core_op = Solve(lower=True, b_ndim=1)
signature = "(m, m),(m) -> (m)" signature = "(m, m),(m) -> (m)"
# Blockwise test case for `Solve` constructed with `b_ndim=2` (matrix-valued core `b`).
# NOTE(review): presumably `BlockwiseOpTester` reads `core_op` and `signature`
# to generate its checks — confirm against the base class, which is not visible here.
class TestSolveMatrix(BlockwiseOpTester):
# Core op under test: `Solve` with `lower=True` and a 2-dimensional `b`.
core_op = Solve(lower=True, b_ndim=2)
# gufunc-style signature: an (m, m) input and an (m, n) input producing an (m, n) output.
signature = "(m, m),(m, n) -> (m, n)"
...@@ -181,7 +181,7 @@ class TestSolveBase(utt.InferShapeTester): ...@@ -181,7 +181,7 @@ class TestSolveBase(utt.InferShapeTester):
( (
matrix, matrix,
functools.partial(tensor, dtype="floatX", shape=(None,) * 3), functools.partial(tensor, dtype="floatX", shape=(None,) * 3),
"`b` must be a matrix or a vector.*", "`b` must have 2 dims.*",
), ),
], ],
) )
...@@ -190,20 +190,20 @@ class TestSolveBase(utt.InferShapeTester): ...@@ -190,20 +190,20 @@ class TestSolveBase(utt.InferShapeTester):
with pytest.raises(ValueError, match=error_message): with pytest.raises(ValueError, match=error_message):
A = A_func() A = A_func()
b = b_func() b = b_func()
SolveBase()(A, b) SolveBase(b_ndim=2)(A, b)
def test__repr__(self): def test__repr__(self):
np.random.default_rng(utt.fetch_seed()) np.random.default_rng(utt.fetch_seed())
A = matrix() A = matrix()
b = matrix() b = matrix()
y = SolveBase()(A, b) y = SolveBase(b_ndim=2)(A, b)
assert y.__repr__() == "SolveBase{lower=False, check_finite=True}.0" assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0"
class TestSolve(utt.InferShapeTester): class TestSolve(utt.InferShapeTester):
def test__init__(self): def test__init__(self):
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
Solve(assume_a="test") Solve(assume_a="test", b_ndim=2)
assert "is not a recognized matrix structure" in str(excinfo.value) assert "is not a recognized matrix structure" in str(excinfo.value)
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) @pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
...@@ -278,7 +278,7 @@ class TestSolve(utt.InferShapeTester): ...@@ -278,7 +278,7 @@ class TestSolve(utt.InferShapeTester):
if config.floatX == "float64": if config.floatX == "float64":
eps = 2e-8 eps = 2e-8
solve_op = Solve(assume_a=assume_a, lower=lower) solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
...@@ -349,19 +349,20 @@ class TestSolveTriangular(utt.InferShapeTester): ...@@ -349,19 +349,20 @@ class TestSolveTriangular(utt.InferShapeTester):
if config.floatX == "float64": if config.floatX == "float64":
eps = 2e-8 eps = 2e-8
solve_op = SolveTriangular(lower=lower) solve_op = SolveTriangular(lower=lower, b_ndim=1 if n is None else 2)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
class TestCholeskySolve(utt.InferShapeTester): class TestCholeskySolve(utt.InferShapeTester):
def setup_method(self): def setup_method(self):
self.op_class = CholeskySolve self.op_class = CholeskySolve
self.op = CholeskySolve()
self.op_upper = CholeskySolve(lower=False)
super().setup_method() super().setup_method()
def test_repr(self): def test_repr(self):
assert repr(CholeskySolve()) == "CholeskySolve(lower=True,check_finite=True)" assert (
repr(CholeskySolve(lower=True, b_ndim=1))
== "CholeskySolve(lower=True,check_finite=True,b_ndim=1)"
)
def test_infer_shape(self): def test_infer_shape(self):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
...@@ -369,7 +370,7 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -369,7 +370,7 @@ class TestCholeskySolve(utt.InferShapeTester):
b = matrix() b = matrix()
self._compile_and_check( self._compile_and_check(
[A, b], # pytensor.function inputs [A, b], # pytensor.function inputs
[self.op(A, b)], # pytensor.function outputs [self.op_class(b_ndim=2)(A, b)], # pytensor.function outputs
# A must be square # A must be square
[ [
np.asarray(rng.random((5, 5)), dtype=config.floatX), np.asarray(rng.random((5, 5)), dtype=config.floatX),
...@@ -383,7 +384,7 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -383,7 +384,7 @@ class TestCholeskySolve(utt.InferShapeTester):
b = vector() b = vector()
self._compile_and_check( self._compile_and_check(
[A, b], # pytensor.function inputs [A, b], # pytensor.function inputs
[self.op(A, b)], # pytensor.function outputs [self.op_class(b_ndim=1)(A, b)], # pytensor.function outputs
# A must be square # A must be square
[ [
np.asarray(rng.random((5, 5)), dtype=config.floatX), np.asarray(rng.random((5, 5)), dtype=config.floatX),
...@@ -397,10 +398,10 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -397,10 +398,10 @@ class TestCholeskySolve(utt.InferShapeTester):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = matrix() A = matrix()
b = matrix() b = matrix()
y = self.op(A, b) y = self.op_class(lower=True, b_ndim=2)(A, b)
cho_solve_lower_func = pytensor.function([A, b], y) cho_solve_lower_func = pytensor.function([A, b], y)
y = self.op_upper(A, b) y = self.op_class(lower=False, b_ndim=2)(A, b)
cho_solve_upper_func = pytensor.function([A, b], y) cho_solve_upper_func = pytensor.function([A, b], y)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
...@@ -435,12 +436,13 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -435,12 +436,13 @@ class TestCholeskySolve(utt.InferShapeTester):
A_val = np.eye(2) A_val = np.eye(2)
b_val = np.ones((2, 1)) b_val = np.ones((2, 1))
op = self.op_class(b_ndim=2)
# try all dtype combinations # try all dtype combinations
for A_dtype, b_dtype in itertools.product(dtypes, dtypes): for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
A = matrix(dtype=A_dtype) A = matrix(dtype=A_dtype)
b = matrix(dtype=b_dtype) b = matrix(dtype=b_dtype)
x = self.op(A, b) x = op(A, b)
fn = function([A, b], x) fn = function([A, b], x)
x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype)) x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论