提交 a3eed0b4 作者: Ricardo Vieira 提交者: Thomas Wiecki

Blockwise some linalg Ops by default

上级 7fb4e70a
......@@ -3764,7 +3764,7 @@ def stacklists(arg):
return arg
def swapaxes(y, axis1, axis2):
def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
"Swap the axes of a tensor."
y = as_tensor_variable(y)
ndim = y.ndim
......
......@@ -10,11 +10,13 @@ from pytensor.graph.op import Op
from pytensor.tensor import basic as at
from pytensor.tensor import math as tm
from pytensor.tensor.basic import as_tensor_variable, extract_diag
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
class MatrixPinv(Op):
__props__ = ("hermitian",)
gufunc_signature = "(m,n)->(n,m)"
def __init__(self, hermitian):
self.hermitian = hermitian
......@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
solve op.
"""
return MatrixPinv(hermitian=hermitian)(x)
return Blockwise(MatrixPinv(hermitian=hermitian))(x)
class MatrixInverse(Op):
......@@ -93,6 +95,8 @@ class MatrixInverse(Op):
"""
__props__ = ()
gufunc_signature = "(m,m)->(m,m)"
gufunc_spec = ("numpy.linalg.inv", 1, 1)
def __init__(self):
pass
......@@ -150,7 +154,7 @@ class MatrixInverse(Op):
return shapes
inv = matrix_inverse = MatrixInverse()
inv = matrix_inverse = Blockwise(MatrixInverse())
def matrix_dot(*args):
......@@ -181,6 +185,8 @@ class Det(Op):
"""
__props__ = ()
gufunc_signature = "(m,m)->()"
gufunc_spec = ("numpy.linalg.det", 1, 1)
def make_node(self, x):
x = as_tensor_variable(x)
......@@ -209,7 +215,7 @@ class Det(Op):
return "Det"
det = Det()
det = Blockwise(Det())
class SLogDet(Op):
......@@ -218,6 +224,8 @@ class SLogDet(Op):
"""
__props__ = ()
gufunc_signature = "(m, m)->(),()"
gufunc_spec = ("numpy.linalg.slogdet", 1, 2)
def make_node(self, x):
x = as_tensor_variable(x)
......@@ -242,7 +250,7 @@ class SLogDet(Op):
return "SLogDet"
slogdet = SLogDet()
slogdet = Blockwise(SLogDet())
class Eig(Op):
......@@ -252,6 +260,8 @@ class Eig(Op):
"""
__props__: Tuple[str, ...] = ()
gufunc_signature = "(m,m)->(m),(m,m)"
gufunc_spec = ("numpy.linalg.eig", 1, 2)
def make_node(self, x):
x = as_tensor_variable(x)
......@@ -270,7 +280,7 @@ class Eig(Op):
return [(n,), (n, n)]
eig = Eig()
eig = Blockwise(Eig())
class Eigh(Eig):
......
import logging
from typing import cast
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor import basic as at
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, log, prod
from pytensor.tensor.nlinalg import Det, MatrixInverse
from pytensor.tensor.nlinalg import MatrixInverse, det
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular
logger = logging.getLogger(__name__)
def is_matrix_transpose(x: TensorVariable) -> bool:
    """Tell whether ``x`` is its input with only the last two axes swapped.

    The answer is ``True`` only when ``x`` is the output of a `DimShuffle`
    that neither drops nor adds dimensions, whose input has at least two
    dimensions, and whose ``new_order`` leaves every leading axis in place
    while exchanging the final two. Anything else yields ``False``.
    """
    apply = x.owner
    if not apply:
        # Root variables (no owner) cannot be a transpose of anything.
        return False

    op = apply.op
    if not isinstance(op, DimShuffle):
        return False
    if op.drop or op.augment:
        # A shuffle that changes the number of dims is not a pure transpose.
        return False

    [source] = apply.inputs
    rank = source.type.ndim
    if rank < 2:
        return False

    # Identity on the batch axes, then the two core axes in reverse order.
    expected_order = (*range(rank - 2), rank - 1, rank - 2)
    return cast(bool, op.new_order == expected_order)
def _T(x: TensorVariable) -> TensorVariable:
    """Transpose the trailing two axes of ``x``, leaving any leading axes alone."""
    return swapaxes(x, axis1=-1, axis2=-2)
@register_canonicalize
@node_rewriter([DimShuffle])
def transinv_to_invtrans(fgraph, node):
if isinstance(node.op, DimShuffle):
if node.op.new_order == (1, 0):
(A,) = node.inputs
if A.owner:
if isinstance(A.owner.op, MatrixInverse):
(X,) = A.owner.inputs
return [A.owner.op(node.op(X))]
if is_matrix_transpose(node.outputs[0]):
(A,) = node.inputs
if (
A.owner
and isinstance(A.owner.op, Blockwise)
and isinstance(A.owner.op.core_op, MatrixInverse)
):
(X,) = A.owner.inputs
return [A.owner.op(node.op(X))]
@register_stabilize
......@@ -37,43 +63,72 @@ def inv_as_solve(fgraph, node):
"""
if isinstance(node.op, (Dot, Dot22)):
l, r = node.inputs
if l.owner and isinstance(l.owner.op, MatrixInverse):
if (
l.owner
and isinstance(l.owner.op, Blockwise)
and isinstance(l.owner.op.core_op, MatrixInverse)
):
return [solve(l.owner.inputs[0], r)]
if r.owner and isinstance(r.owner.op, MatrixInverse):
if (
r.owner
and isinstance(r.owner.op, Blockwise)
and isinstance(r.owner.op.core_op, MatrixInverse)
):
x = r.owner.inputs[0]
if getattr(x.tag, "symmetric", None) is True:
return [solve(x, l.T).T]
return [_T(solve(x, _T(l)))]
else:
return [solve(x.T, l.T).T]
return [_T(solve(_T(x), _T(l)))]
@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
def generic_solve_to_solve_triangular(fgraph, node):
    """
    If any solve() is applied to the output of a cholesky op, then
    replace it with a triangular solve.

    The rewrite only fires for a `Blockwise` whose core op is a `Solve` with
    ``assume_a == "gen"``. Two forms of the first input ``A`` are recognised:

    * ``A`` is the direct output of a blockwise `Cholesky`: the triangular
      solve uses the same ``lower`` flag as the factorization.
    * ``A`` is a matrix transpose (last two axes swapped) of such an output:
      the triangle is flipped, so ``lower`` is negated.

    Returns a one-element list with the replacement variable, or ``None``
    when the pattern does not match.
    """
    core_op = node.op.core_op
    if not isinstance(core_op, Solve) or core_op.assume_a != "gen":
        return None

    A, b = node.inputs  # result is solution Ax=b

    def _cholesky_core_op(var):
        # Return the core Cholesky op producing `var`, or None if `var` is
        # not the output of a Blockwise-wrapped Cholesky.
        owner = var.owner
        if (
            owner
            and isinstance(owner.op, Blockwise)
            and isinstance(owner.op.core_op, Cholesky)
        ):
            return owner.op.core_op
        return None

    chol = _cholesky_core_op(A)
    if chol is not None:
        lower = bool(chol.lower)
    elif is_matrix_transpose(A):
        (A_T,) = A.owner.inputs
        # The Cholesky op lives on `core_op` of the Blockwise wrapper; testing
        # the wrapper itself against Cholesky can never succeed.
        chol = _cholesky_core_op(A_T)
        if chol is None:
            return None
        # Transposing a lower-triangular factor gives an upper-triangular one.
        lower = not chol.lower
    else:
        return None

    return [solve_triangular(A, b, lower=lower, b_ndim=core_op.b_ndim)]
@register_canonicalize
......@@ -81,34 +136,33 @@ def generic_solve_to_solve_triangular(fgraph, node):
@register_specialize
@node_rewriter([DimShuffle])
def no_transpose_symmetric(fgraph, node):
if isinstance(node.op, DimShuffle):
if is_matrix_transpose(node.outputs[0]):
x = node.inputs[0]
if x.type.ndim == 2 and getattr(x.tag, "symmetric", None) is True:
if node.op.new_order == [1, 0]:
return [x]
if getattr(x.tag, "symmetric", None):
return [x]
@register_stabilize
@node_rewriter([Solve])
@node_rewriter([Blockwise])
def psd_solve_with_chol(fgraph, node):
"""
This utilizes a boolean `psd` tag on matrices.
"""
if isinstance(node.op, Solve):
if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2:
A, b = node.inputs # result is solution Ax=b
if getattr(A.tag, "psd", None) is True:
L = cholesky(A)
# N.B. this can be further reduced to a yet-unwritten cho_solve Op
# __if__ no other Op makes use of the the L matrix during the
# __if__ no other Op makes use of the L matrix during the
# stabilization
Li_b = Solve(assume_a="sym", lower=True)(L, b)
x = Solve(assume_a="sym", lower=False)(L.T, Li_b)
Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2)
x = solve(_T(L), Li_b, assume_a="sym", lower=False, b_ndim=2)
return [x]
@register_canonicalize
@register_stabilize
@node_rewriter([Cholesky])
@node_rewriter([Blockwise])
def cholesky_ldotlt(fgraph, node):
"""
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
......@@ -116,7 +170,7 @@ def cholesky_ldotlt(fgraph, node):
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
"""
if not isinstance(node.op, Cholesky):
if not isinstance(node.op.core_op, Cholesky):
return
A = node.inputs[0]
......@@ -128,45 +182,40 @@ def cholesky_ldotlt(fgraph, node):
# cholesky(dot(L,L.T)) case
if (
getattr(l.tag, "lower_triangular", False)
and r.owner
and isinstance(r.owner.op, DimShuffle)
and r.owner.op.new_order == (1, 0)
and is_matrix_transpose(r)
and r.owner.inputs[0] == l
):
if node.op.lower:
if node.op.core_op.lower:
return [l]
return [r]
# cholesky(dot(U.T,U)) case
if (
getattr(r.tag, "upper_triangular", False)
and l.owner
and isinstance(l.owner.op, DimShuffle)
and l.owner.op.new_order == (1, 0)
and is_matrix_transpose(l)
and l.owner.inputs[0] == r
):
if node.op.lower:
if node.op.core_op.lower:
return [l]
return [r]
@register_stabilize
@register_specialize
@node_rewriter([Det])
@node_rewriter([det])
def local_det_chol(fgraph, node):
"""
If we have det(X) and there is already an L=cholesky(X)
floating around, then we can use prod(diag(L)) to get the determinant.
"""
if isinstance(node.op, Det):
(x,) = node.inputs
for cl, xpos in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Cholesky):
L = cl.outputs[0]
return [prod(at.extract_diag(L) ** 2)]
(x,) = node.inputs
for cl, xpos in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
L = cl.outputs[0]
return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)]
@register_canonicalize
......@@ -177,16 +226,15 @@ def local_log_prod_sqr(fgraph, node):
"""
This utilizes a boolean `positive` tag on matrices.
"""
if node.op == log:
(x,) = node.inputs
if x.owner and isinstance(x.owner.op, Prod):
# we cannot always make this substitution because
# the prod might include negative terms
p = x.owner.inputs[0]
# p is the matrix we're reducing with prod
if getattr(p.tag, "positive", None) is True:
return [log(p).sum(axis=x.owner.op.axis)]
# TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication.
(x,) = node.inputs
if x.owner and isinstance(x.owner.op, Prod):
# we cannot always make this substitution because
# the prod might include negative terms
p = x.owner.inputs[0]
# p is the matrix we're reducing with prod
if getattr(p.tag, "positive", None) is True:
return [log(p).sum(axis=x.owner.op.axis)]
# TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication.
import logging
import typing
import warnings
from typing import TYPE_CHECKING, Literal, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
import scipy.linalg
......@@ -13,6 +13,7 @@ from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at
from pytensor.tensor import math as atm
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector
......@@ -48,6 +49,7 @@ class Cholesky(Op):
# TODO: LAPACK wrapper with in-place behavior, for solve also
__props__ = ("lower", "destructive", "on_error")
gufunc_signature = "(m,m)->(m,m)"
def __init__(self, *, lower=True, on_error="raise"):
self.lower = lower
......@@ -109,7 +111,7 @@ class Cholesky(Op):
def conjugate_solve_triangular(outer, inner):
"""Computes L^{-T} P L^{-1} for lower-triangular L."""
solve_upper = SolveTriangular(lower=False)
solve_upper = SolveTriangular(lower=False, b_ndim=2)
return solve_upper(outer.T, solve_upper(outer.T, inner.T).T)
s = conjugate_solve_triangular(
......@@ -128,7 +130,7 @@ class Cholesky(Op):
def cholesky(x, lower=True, on_error="raise"):
return Cholesky(lower=lower, on_error=on_error)(x)
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
class SolveBase(Op):
......@@ -137,6 +139,7 @@ class SolveBase(Op):
__props__ = (
"lower",
"check_finite",
"b_ndim",
)
def __init__(
......@@ -144,9 +147,16 @@ class SolveBase(Op):
*,
lower=False,
check_finite=True,
b_ndim,
):
self.lower = lower
self.check_finite = check_finite
assert b_ndim in (1, 2)
self.b_ndim = b_ndim
if b_ndim == 1:
self.gufunc_signature = "(m,m),(m)->(m)"
else:
self.gufunc_signature = "(m,m),(m,n)->(m,n)"
def perform(self, node, inputs, outputs):
pass
......@@ -157,8 +167,8 @@ class SolveBase(Op):
if A.ndim != 2:
raise ValueError(f"`A` must be a matrix; got {A.type} instead.")
if b.ndim not in (1, 2):
raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.")
if b.ndim != self.b_ndim:
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
# Infer dtype by solving the most simple case with 1x1 matrices
o_dtype = scipy.linalg.solve(
......@@ -209,6 +219,16 @@ class SolveBase(Op):
return [A_bar, b_bar]
def _default_b_ndim(b, b_ndim):
if b_ndim is not None:
assert b_ndim in (1, 2)
return b_ndim
b = as_tensor_variable(b)
if b_ndim is None:
return min(b.ndim, 2) # By default assume the core case is a matrix
class CholeskySolve(SolveBase):
def __init__(self, **kwargs):
kwargs.setdefault("lower", True)
......@@ -228,7 +248,7 @@ class CholeskySolve(SolveBase):
raise NotImplementedError()
def cho_solve(c_and_lower, b, *, check_finite=True):
def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: Optional[int] = None):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
Parameters
......@@ -241,9 +261,15 @@ def cho_solve(c_and_lower, b, *, check_finite=True):
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
"""
A, lower = c_and_lower
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
b_ndim = _default_b_ndim(b, b_ndim)
return Blockwise(
CholeskySolve(lower=lower, check_finite=check_finite, b_ndim=b_ndim)
)(A, b)
class SolveTriangular(SolveBase):
......@@ -254,6 +280,7 @@ class SolveTriangular(SolveBase):
"unit_diagonal",
"lower",
"check_finite",
"b_ndim",
)
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
......@@ -291,6 +318,7 @@ def solve_triangular(
lower: bool = False,
unit_diagonal: bool = False,
check_finite: bool = True,
b_ndim: Optional[int] = None,
) -> TensorVariable:
"""Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.
......@@ -314,12 +342,19 @@ def solve_triangular(
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
"""
return SolveTriangular(
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
check_finite=check_finite,
b_ndim = _default_b_ndim(b, b_ndim)
return Blockwise(
SolveTriangular(
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
check_finite=check_finite,
b_ndim=b_ndim,
)
)(a, b)
......@@ -332,6 +367,7 @@ class Solve(SolveBase):
"assume_a",
"lower",
"check_finite",
"b_ndim",
)
def __init__(self, *, assume_a="gen", **kwargs):
......@@ -352,7 +388,15 @@ class Solve(SolveBase):
)
def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
def solve(
a,
b,
*,
assume_a="gen",
lower=False,
check_finite=True,
b_ndim: Optional[int] = None,
):
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
If the data matrix is known to be a particular type then supplying the
......@@ -375,9 +419,9 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
Parameters
----------
a : (N, N) array_like
a : (..., N, N) array_like
Square input data
b : (N, NRHS) array_like
b : (..., N, NRHS) array_like
Input data for the right hand side.
lower : bool, optional
If True, only the data contained in the lower triangle of `a`. Default
......@@ -388,11 +432,18 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
(crashes, non-termination) if the inputs do contain infinities or NaNs.
assume_a : str, optional
Valid entries are explained above.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
"""
return Solve(
lower=lower,
check_finite=check_finite,
assume_a=assume_a,
b_ndim = _default_b_ndim(b, b_ndim)
return Blockwise(
Solve(
lower=lower,
check_finite=check_finite,
assume_a=assume_a,
b_ndim=b_ndim,
)
)(a, b)
......
......@@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc):
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower=lower)(A, x)
g = slinalg.Solve(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
......@@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc):
],
)
def test_SolveTriangular(A, x, lower, exc):
g = slinalg.SolveTriangular(lower=lower)(A, x)
g = slinalg.SolveTriangular(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
......
......@@ -9,11 +9,12 @@ from pytensor import function
from pytensor import tensor as at
from pytensor.compile import get_default_mode
from pytensor.configdefaults import config
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.type import dmatrix, matrix, vector
from tests import unittest_tools as utt
from tests.test_rop import break_op
......@@ -23,7 +24,7 @@ def test_rop_lop():
mx = matrix("mx")
mv = matrix("mv")
v = vector("v")
y = matrix_inverse(mx).sum(axis=0)
y = MatrixInverse()(mx).sum(axis=0)
yv = pytensor.gradient.Rop(y, mx, mv)
rop_f = function([mx, mv], yv)
......@@ -83,13 +84,11 @@ def test_transinv_to_invtrans():
def test_generic_solve_to_solve_triangular():
cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False)
A = matrix("A")
x = matrix("x")
L = cholesky_lower(A)
U = cholesky_upper(A)
L = cholesky(A, lower=True)
U = cholesky(A, lower=False)
b1 = solve(L, x)
b2 = solve(U, x)
f = pytensor.function([A, x], b1)
......@@ -130,15 +129,15 @@ def test_matrix_inverse_solve():
b = dmatrix("b")
node = matrix_inverse(A).dot(b).owner
[out] = inv_as_solve.transform(None, node)
assert isinstance(out.owner.op, Solve)
assert isinstance(out.owner.op, Blockwise) and isinstance(
out.owner.op.core_op, Solve
)
@pytest.mark.parametrize("tag", ("lower", "upper", None))
@pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
@pytest.mark.parametrize("product", ("lower", "upper", None))
def test_cholesky_ldotlt(tag, cholesky_form, product):
cholesky = Cholesky(lower=(cholesky_form == "lower"))
transform_removes_chol = tag is not None and product == tag
transform_transposes = transform_removes_chol and cholesky_form != tag
......@@ -153,11 +152,9 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
else:
M = A
C = cholesky(M)
C = cholesky(M, lower=(cholesky_form == "lower"))
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
print(f.maker.fgraph.apply_nodes)
no_cholesky_in_graph = not any(
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
)
......
......@@ -24,6 +24,7 @@ def test_vectorize_blockwise():
assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, MatrixInverse
)
assert vect_node.op.signature == ("(m,m)->(m,m)")
assert vect_node.inputs[0] is tns
# Useless blockwise
......@@ -253,6 +254,11 @@ class TestMatrixInverse(MatrixOpBlockwiseTester):
signature = "(m, m) -> (m, m)"
class TestSolve(BlockwiseOpTester):
core_op = Solve(lower=True)
class TestSolveVector(BlockwiseOpTester):
core_op = Solve(lower=True, b_ndim=1)
signature = "(m, m),(m) -> (m)"
class TestSolveMatrix(BlockwiseOpTester):
core_op = Solve(lower=True, b_ndim=2)
signature = "(m, m),(m, n) -> (m, n)"
......@@ -181,7 +181,7 @@ class TestSolveBase(utt.InferShapeTester):
(
matrix,
functools.partial(tensor, dtype="floatX", shape=(None,) * 3),
"`b` must be a matrix or a vector.*",
"`b` must have 2 dims.*",
),
],
)
......@@ -190,20 +190,20 @@ class TestSolveBase(utt.InferShapeTester):
with pytest.raises(ValueError, match=error_message):
A = A_func()
b = b_func()
SolveBase()(A, b)
SolveBase(b_ndim=2)(A, b)
def test__repr__(self):
np.random.default_rng(utt.fetch_seed())
A = matrix()
b = matrix()
y = SolveBase()(A, b)
assert y.__repr__() == "SolveBase{lower=False, check_finite=True}.0"
y = SolveBase(b_ndim=2)(A, b)
assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0"
class TestSolve(utt.InferShapeTester):
def test__init__(self):
with pytest.raises(ValueError) as excinfo:
Solve(assume_a="test")
Solve(assume_a="test", b_ndim=2)
assert "is not a recognized matrix structure" in str(excinfo.value)
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
......@@ -278,7 +278,7 @@ class TestSolve(utt.InferShapeTester):
if config.floatX == "float64":
eps = 2e-8
solve_op = Solve(assume_a=assume_a, lower=lower)
solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
......@@ -349,19 +349,20 @@ class TestSolveTriangular(utt.InferShapeTester):
if config.floatX == "float64":
eps = 2e-8
solve_op = SolveTriangular(lower=lower)
solve_op = SolveTriangular(lower=lower, b_ndim=1 if n is None else 2)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
class TestCholeskySolve(utt.InferShapeTester):
def setup_method(self):
self.op_class = CholeskySolve
self.op = CholeskySolve()
self.op_upper = CholeskySolve(lower=False)
super().setup_method()
def test_repr(self):
assert repr(CholeskySolve()) == "CholeskySolve(lower=True,check_finite=True)"
assert (
repr(CholeskySolve(lower=True, b_ndim=1))
== "CholeskySolve(lower=True,check_finite=True,b_ndim=1)"
)
def test_infer_shape(self):
rng = np.random.default_rng(utt.fetch_seed())
......@@ -369,7 +370,7 @@ class TestCholeskySolve(utt.InferShapeTester):
b = matrix()
self._compile_and_check(
[A, b], # pytensor.function inputs
[self.op(A, b)], # pytensor.function outputs
[self.op_class(b_ndim=2)(A, b)], # pytensor.function outputs
# A must be square
[
np.asarray(rng.random((5, 5)), dtype=config.floatX),
......@@ -383,7 +384,7 @@ class TestCholeskySolve(utt.InferShapeTester):
b = vector()
self._compile_and_check(
[A, b], # pytensor.function inputs
[self.op(A, b)], # pytensor.function outputs
[self.op_class(b_ndim=1)(A, b)], # pytensor.function outputs
# A must be square
[
np.asarray(rng.random((5, 5)), dtype=config.floatX),
......@@ -397,10 +398,10 @@ class TestCholeskySolve(utt.InferShapeTester):
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
b = matrix()
y = self.op(A, b)
y = self.op_class(lower=True, b_ndim=2)(A, b)
cho_solve_lower_func = pytensor.function([A, b], y)
y = self.op_upper(A, b)
y = self.op_class(lower=False, b_ndim=2)(A, b)
cho_solve_upper_func = pytensor.function([A, b], y)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
......@@ -435,12 +436,13 @@ class TestCholeskySolve(utt.InferShapeTester):
A_val = np.eye(2)
b_val = np.ones((2, 1))
op = self.op_class(b_ndim=2)
# try all dtype combinations
for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
A = matrix(dtype=A_dtype)
b = matrix(dtype=b_dtype)
x = self.op(A, b)
x = op(A, b)
fn = function([A, b], x)
x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论