提交 a3eed0b4 作者: Ricardo Vieira 提交者: Thomas Wiecki

Blockwise some linalg Ops by default

上级 7fb4e70a
...@@ -3764,7 +3764,7 @@ def stacklists(arg): ...@@ -3764,7 +3764,7 @@ def stacklists(arg):
return arg return arg
def swapaxes(y, axis1, axis2): def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
"Swap the axes of a tensor." "Swap the axes of a tensor."
y = as_tensor_variable(y) y = as_tensor_variable(y)
ndim = y.ndim ndim = y.ndim
......
...@@ -10,11 +10,13 @@ from pytensor.graph.op import Op ...@@ -10,11 +10,13 @@ from pytensor.graph.op import Op
from pytensor.tensor import basic as at from pytensor.tensor import basic as at
from pytensor.tensor import math as tm from pytensor.tensor import math as tm
from pytensor.tensor.basic import as_tensor_variable, extract_diag from pytensor.tensor.basic import as_tensor_variable, extract_diag
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
class MatrixPinv(Op): class MatrixPinv(Op):
__props__ = ("hermitian",) __props__ = ("hermitian",)
gufunc_signature = "(m,n)->(n,m)"
def __init__(self, hermitian): def __init__(self, hermitian):
self.hermitian = hermitian self.hermitian = hermitian
...@@ -75,7 +77,7 @@ def pinv(x, hermitian=False): ...@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
solve op. solve op.
""" """
return MatrixPinv(hermitian=hermitian)(x) return Blockwise(MatrixPinv(hermitian=hermitian))(x)
class MatrixInverse(Op): class MatrixInverse(Op):
...@@ -93,6 +95,8 @@ class MatrixInverse(Op): ...@@ -93,6 +95,8 @@ class MatrixInverse(Op):
""" """
__props__ = () __props__ = ()
gufunc_signature = "(m,m)->(m,m)"
gufunc_spec = ("numpy.linalg.inv", 1, 1)
def __init__(self): def __init__(self):
pass pass
...@@ -150,7 +154,7 @@ class MatrixInverse(Op): ...@@ -150,7 +154,7 @@ class MatrixInverse(Op):
return shapes return shapes
inv = matrix_inverse = MatrixInverse() inv = matrix_inverse = Blockwise(MatrixInverse())
def matrix_dot(*args): def matrix_dot(*args):
...@@ -181,6 +185,8 @@ class Det(Op): ...@@ -181,6 +185,8 @@ class Det(Op):
""" """
__props__ = () __props__ = ()
gufunc_signature = "(m,m)->()"
gufunc_spec = ("numpy.linalg.det", 1, 1)
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -209,7 +215,7 @@ class Det(Op): ...@@ -209,7 +215,7 @@ class Det(Op):
return "Det" return "Det"
det = Det() det = Blockwise(Det())
class SLogDet(Op): class SLogDet(Op):
...@@ -218,6 +224,8 @@ class SLogDet(Op): ...@@ -218,6 +224,8 @@ class SLogDet(Op):
""" """
__props__ = () __props__ = ()
gufunc_signature = "(m, m)->(),()"
gufunc_spec = ("numpy.linalg.slogdet", 1, 2)
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -242,7 +250,7 @@ class SLogDet(Op): ...@@ -242,7 +250,7 @@ class SLogDet(Op):
return "SLogDet" return "SLogDet"
slogdet = SLogDet() slogdet = Blockwise(SLogDet())
class Eig(Op): class Eig(Op):
...@@ -252,6 +260,8 @@ class Eig(Op): ...@@ -252,6 +260,8 @@ class Eig(Op):
""" """
__props__: Tuple[str, ...] = () __props__: Tuple[str, ...] = ()
gufunc_signature = "(m,m)->(m),(m,m)"
gufunc_spec = ("numpy.linalg.eig", 1, 2)
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -270,7 +280,7 @@ class Eig(Op): ...@@ -270,7 +280,7 @@ class Eig(Op):
return [(n,), (n, n)] return [(n,), (n, n)]
eig = Eig() eig = Blockwise(Eig())
class Eigh(Eig): class Eigh(Eig):
......
import logging import logging
from typing import cast
from pytensor.graph.rewriting.basic import node_rewriter from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor import basic as at from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, Prod, log, prod from pytensor.tensor.math import Dot, Prod, log, prod
from pytensor.tensor.nlinalg import Det, MatrixInverse from pytensor.tensor.nlinalg import MatrixInverse, det
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
register_stabilize, register_stabilize,
) )
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve, solve_triangular
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def is_matrix_transpose(x: TensorVariable) -> bool:
    """Tell whether `x` is the result of swapping the last two axes of its input.

    `x` qualifies only when it is produced by a `DimShuffle` whose `drop` and
    `augment` attributes are both falsy and whose `new_order` leaves every
    leading axis in place while exchanging the final two. Variables without an
    owner, and inputs with fewer than two dimensions, never qualify.
    """
    apply = x.owner
    if not apply:
        return False

    op = apply.op
    if not isinstance(op, DimShuffle):
        return False
    if op.drop or op.augment:
        return False

    [source] = apply.inputs
    rank = source.type.ndim
    if rank < 2:
        return False

    # Identity on the leading axes, then the last two axes exchanged.
    expected_order = (*range(rank - 2), rank - 1, rank - 2)
    return cast(bool, op.new_order == expected_order)
def _T(x: TensorVariable) -> TensorVariable:
    """Swap the last two axes of `x`, leaving any leading axes untouched."""
    return swapaxes(x, axis1=-1, axis2=-2)
@register_canonicalize
@node_rewriter([DimShuffle])
def transinv_to_invtrans(fgraph, node):
    """Rewrite ``transpose(inv(X))`` as ``inv(transpose(X))``.

    Applies when `node` is a transpose of the last two axes (as decided by
    `is_matrix_transpose`) and its input is produced by a `Blockwise` whose
    core op is `MatrixInverse`. Returns ``None`` when the pattern is absent.
    """
    if not is_matrix_transpose(node.outputs[0]):
        return None

    (inverse_out,) = node.inputs
    producer = inverse_out.owner
    is_blockwise_inverse = (
        producer
        and isinstance(producer.op, Blockwise)
        and isinstance(producer.op.core_op, MatrixInverse)
    )
    if not is_blockwise_inverse:
        return None

    (X,) = producer.inputs
    # Re-apply the same transpose op to X, then the same inverse op on top.
    return [producer.op(node.op(X))]
...@@ -37,43 +63,72 @@ def inv_as_solve(fgraph, node): ...@@ -37,43 +63,72 @@ def inv_as_solve(fgraph, node):
""" """
if isinstance(node.op, (Dot, Dot22)): if isinstance(node.op, (Dot, Dot22)):
l, r = node.inputs l, r = node.inputs
if l.owner and isinstance(l.owner.op, MatrixInverse): if (
l.owner
and isinstance(l.owner.op, Blockwise)
and isinstance(l.owner.op.core_op, MatrixInverse)
):
return [solve(l.owner.inputs[0], r)] return [solve(l.owner.inputs[0], r)]
if r.owner and isinstance(r.owner.op, MatrixInverse): if (
r.owner
and isinstance(r.owner.op, Blockwise)
and isinstance(r.owner.op.core_op, MatrixInverse)
):
x = r.owner.inputs[0] x = r.owner.inputs[0]
if getattr(x.tag, "symmetric", None) is True: if getattr(x.tag, "symmetric", None) is True:
return [solve(x, l.T).T] return [_T(solve(x, _T(l)))]
else: else:
return [solve(x.T, l.T).T] return [_T(solve(_T(x), _T(l)))]
def _cholesky_core_op(var):
    """Return the core `Cholesky` op that produced `var`, or ``None``.

    Only a `Blockwise` whose `core_op` is a `Cholesky` counts. The check must
    go through `core_op`: the `Blockwise` wrapper itself is not a `Cholesky`
    instance, and the `lower` flag is read from the core op.
    """
    owner = var.owner
    if (
        owner
        and isinstance(owner.op, Blockwise)
        and isinstance(owner.op.core_op, Cholesky)
    ):
        return owner.op.core_op
    return None


@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
def generic_solve_to_solve_triangular(fgraph, node):
    """
    If any solve() is applied to the output of a cholesky op, then
    replace it with a triangular solve.

    Only general solves (``assume_a == "gen"``) are rewritten. Two patterns
    are recognised for the `A` input of the solve:

    * `A` is itself a Cholesky factor: the triangular solve uses the same
      `lower` flag as the factorization.
    * `A` is a transpose of the last two axes of a Cholesky factor: the
      triangular solve uses the opposite `lower` flag, because transposing
      exchanges the lower and upper triangles.

    The `b_ndim` of the original solve is forwarded unchanged. Returns
    ``None`` when neither pattern matches.
    """
    core_op = node.op.core_op
    if not (isinstance(core_op, Solve) and core_op.assume_a == "gen"):
        return None

    A, b = node.inputs  # result is solution Ax=b

    chol = _cholesky_core_op(A)
    if chol is not None:
        return [
            solve_triangular(A, b, lower=bool(chol.lower), b_ndim=core_op.b_ndim)
        ]

    if is_matrix_transpose(A):
        (A_T,) = A.owner.inputs
        chol = _cholesky_core_op(A_T)
        if chol is not None:
            # A is the transposed factor, so its triangle is the opposite one.
            return [
                solve_triangular(
                    A, b, lower=not chol.lower, b_ndim=core_op.b_ndim
                )
            ]

    return None
@register_canonicalize @register_canonicalize
...@@ -81,34 +136,33 @@ def generic_solve_to_solve_triangular(fgraph, node): ...@@ -81,34 +136,33 @@ def generic_solve_to_solve_triangular(fgraph, node):
@register_specialize @register_specialize
@node_rewriter([DimShuffle]) @node_rewriter([DimShuffle])
def no_transpose_symmetric(fgraph, node): def no_transpose_symmetric(fgraph, node):
if isinstance(node.op, DimShuffle): if is_matrix_transpose(node.outputs[0]):
x = node.inputs[0] x = node.inputs[0]
if x.type.ndim == 2 and getattr(x.tag, "symmetric", None) is True: if getattr(x.tag, "symmetric", None):
if node.op.new_order == [1, 0]:
return [x] return [x]
@register_stabilize
@node_rewriter([Blockwise])
def psd_solve_with_chol(fgraph, node):
    """
    This utilizes a boolean `psd` tag on matrices.

    For a blockwise `Solve` with a matrix right-hand side (``b_ndim == 2``)
    whose `A` input carries ``tag.psd is True``, factor ``A = L L^T`` with
    `cholesky` and replace the solve by two triangular solves:
    ``L y = b`` followed by ``L^T x = y``. Returns ``None`` otherwise.
    """
    core_op = node.op.core_op
    if not (isinstance(core_op, Solve) and core_op.b_ndim == 2):
        return None

    A, b = node.inputs  # result is solution Ax=b
    if getattr(A.tag, "psd", None) is not True:
        return None

    L = cholesky(A)
    # N.B. this can be further reduced to a yet-unwritten cho_solve Op
    #     __if__ no other Op makes use of the L matrix during the
    #     stabilization
    #
    # The factor is triangular, not symmetric, so it has to go through a
    # triangular solve. A generic solve with assume_a="sym" would rebuild a
    # symmetric matrix from one triangle of L and solve a different system.
    Li_b = solve_triangular(L, b, lower=True, b_ndim=2)
    x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2)
    return [x]
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@node_rewriter([Cholesky]) @node_rewriter([Blockwise])
def cholesky_ldotlt(fgraph, node): def cholesky_ldotlt(fgraph, node):
""" """
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular, rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
...@@ -116,7 +170,7 @@ def cholesky_ldotlt(fgraph, node): ...@@ -116,7 +170,7 @@ def cholesky_ldotlt(fgraph, node):
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
""" """
if not isinstance(node.op, Cholesky): if not isinstance(node.op.core_op, Cholesky):
return return
A = node.inputs[0] A = node.inputs[0]
...@@ -128,45 +182,40 @@ def cholesky_ldotlt(fgraph, node): ...@@ -128,45 +182,40 @@ def cholesky_ldotlt(fgraph, node):
# cholesky(dot(L,L.T)) case # cholesky(dot(L,L.T)) case
if ( if (
getattr(l.tag, "lower_triangular", False) getattr(l.tag, "lower_triangular", False)
and r.owner and is_matrix_transpose(r)
and isinstance(r.owner.op, DimShuffle)
and r.owner.op.new_order == (1, 0)
and r.owner.inputs[0] == l and r.owner.inputs[0] == l
): ):
if node.op.lower: if node.op.core_op.lower:
return [l] return [l]
return [r] return [r]
# cholesky(dot(U.T,U)) case # cholesky(dot(U.T,U)) case
if ( if (
getattr(r.tag, "upper_triangular", False) getattr(r.tag, "upper_triangular", False)
and l.owner and is_matrix_transpose(l)
and isinstance(l.owner.op, DimShuffle)
and l.owner.op.new_order == (1, 0)
and l.owner.inputs[0] == r and l.owner.inputs[0] == r
): ):
if node.op.lower: if node.op.core_op.lower:
return [l] return [l]
return [r] return [r]
@register_stabilize
@register_specialize
@node_rewriter([det])
def local_det_chol(fgraph, node):
    """
    If we have det(X) and there is already an L=cholesky(X)
    floating around, then we can use prod(diag(L)) to get the determinant.

    The replacement is the product over the last axis of the squared
    diagonal of `L`, where the diagonal is taken over the last two axes.
    Returns ``None`` when no client of `X` is a blockwise Cholesky.
    """
    (x,) = node.inputs
    for client, _ in fgraph.clients[x]:
        if client == "output":
            continue
        client_op = client.op
        if isinstance(client_op, Blockwise) and isinstance(
            client_op.core_op, Cholesky
        ):
            L = client.outputs[0]
            squared_diag = diagonal(L, axis1=-2, axis2=-1) ** 2
            return [prod(squared_diag, axis=-1)]
    return None
@register_canonicalize @register_canonicalize
...@@ -177,7 +226,6 @@ def local_log_prod_sqr(fgraph, node): ...@@ -177,7 +226,6 @@ def local_log_prod_sqr(fgraph, node):
""" """
This utilizes a boolean `positive` tag on matrices. This utilizes a boolean `positive` tag on matrices.
""" """
if node.op == log:
(x,) = node.inputs (x,) = node.inputs
if x.owner and isinstance(x.owner.op, Prod): if x.owner and isinstance(x.owner.op, Prod):
# we cannot always make this substitution because # we cannot always make this substitution because
......
import logging import logging
import typing import typing
import warnings import warnings
from typing import TYPE_CHECKING, Literal, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np import numpy as np
import scipy.linalg import scipy.linalg
...@@ -13,6 +13,7 @@ from pytensor.graph.op import Op ...@@ -13,6 +13,7 @@ from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at from pytensor.tensor import basic as at
from pytensor.tensor import math as atm from pytensor.tensor import math as atm
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import matrix_dot from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.type import matrix, tensor, vector
...@@ -48,6 +49,7 @@ class Cholesky(Op): ...@@ -48,6 +49,7 @@ class Cholesky(Op):
# TODO: LAPACK wrapper with in-place behavior, for solve also # TODO: LAPACK wrapper with in-place behavior, for solve also
__props__ = ("lower", "destructive", "on_error") __props__ = ("lower", "destructive", "on_error")
gufunc_signature = "(m,m)->(m,m)"
def __init__(self, *, lower=True, on_error="raise"): def __init__(self, *, lower=True, on_error="raise"):
self.lower = lower self.lower = lower
...@@ -109,7 +111,7 @@ class Cholesky(Op): ...@@ -109,7 +111,7 @@ class Cholesky(Op):
def conjugate_solve_triangular(outer, inner): def conjugate_solve_triangular(outer, inner):
"""Computes L^{-T} P L^{-1} for lower-triangular L.""" """Computes L^{-T} P L^{-1} for lower-triangular L."""
solve_upper = SolveTriangular(lower=False) solve_upper = SolveTriangular(lower=False, b_ndim=2)
return solve_upper(outer.T, solve_upper(outer.T, inner.T).T) return solve_upper(outer.T, solve_upper(outer.T, inner.T).T)
s = conjugate_solve_triangular( s = conjugate_solve_triangular(
...@@ -128,7 +130,7 @@ class Cholesky(Op): ...@@ -128,7 +130,7 @@ class Cholesky(Op):
def cholesky(x, lower=True, on_error="raise"):
    """Apply a `Cholesky` op, wrapped in `Blockwise`, to `x`.

    `lower` and `on_error` are forwarded unchanged to the `Cholesky`
    constructor.
    """
    core_op = Cholesky(lower=lower, on_error=on_error)
    return Blockwise(core_op)(x)
class SolveBase(Op): class SolveBase(Op):
...@@ -137,6 +139,7 @@ class SolveBase(Op): ...@@ -137,6 +139,7 @@ class SolveBase(Op):
__props__ = ( __props__ = (
"lower", "lower",
"check_finite", "check_finite",
"b_ndim",
) )
def __init__( def __init__(
...@@ -144,9 +147,16 @@ class SolveBase(Op): ...@@ -144,9 +147,16 @@ class SolveBase(Op):
*, *,
lower=False, lower=False,
check_finite=True, check_finite=True,
b_ndim,
): ):
self.lower = lower self.lower = lower
self.check_finite = check_finite self.check_finite = check_finite
assert b_ndim in (1, 2)
self.b_ndim = b_ndim
if b_ndim == 1:
self.gufunc_signature = "(m,m),(m)->(m)"
else:
self.gufunc_signature = "(m,m),(m,n)->(m,n)"
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
pass pass
...@@ -157,8 +167,8 @@ class SolveBase(Op): ...@@ -157,8 +167,8 @@ class SolveBase(Op):
if A.ndim != 2: if A.ndim != 2:
raise ValueError(f"`A` must be a matrix; got {A.type} instead.") raise ValueError(f"`A` must be a matrix; got {A.type} instead.")
if b.ndim not in (1, 2): if b.ndim != self.b_ndim:
raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.") raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")
# Infer dtype by solving the most simple case with 1x1 matrices # Infer dtype by solving the most simple case with 1x1 matrices
o_dtype = scipy.linalg.solve( o_dtype = scipy.linalg.solve(
...@@ -209,6 +219,16 @@ class SolveBase(Op): ...@@ -209,6 +219,16 @@ class SolveBase(Op):
return [A_bar, b_bar] return [A_bar, b_bar]
def _default_b_ndim(b, b_ndim):
if b_ndim is not None:
assert b_ndim in (1, 2)
return b_ndim
b = as_tensor_variable(b)
if b_ndim is None:
return min(b.ndim, 2) # By default assume the core case is a matrix
class CholeskySolve(SolveBase): class CholeskySolve(SolveBase):
def __init__(self, **kwargs): def __init__(self, **kwargs):
kwargs.setdefault("lower", True) kwargs.setdefault("lower", True)
...@@ -228,7 +248,7 @@ class CholeskySolve(SolveBase): ...@@ -228,7 +248,7 @@ class CholeskySolve(SolveBase):
raise NotImplementedError() raise NotImplementedError()
def cho_solve(c_and_lower, b, *, check_finite=True): def cho_solve(c_and_lower, b, *, check_finite=True, b_ndim: Optional[int] = None):
"""Solve the linear equations A x = b, given the Cholesky factorization of A. """Solve the linear equations A x = b, given the Cholesky factorization of A.
Parameters Parameters
...@@ -241,9 +261,15 @@ def cho_solve(c_and_lower, b, *, check_finite=True): ...@@ -241,9 +261,15 @@ def cho_solve(c_and_lower, b, *, check_finite=True):
Whether to check that the input matrices contain only finite numbers. Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs. (crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
""" """
A, lower = c_and_lower A, lower = c_and_lower
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b) b_ndim = _default_b_ndim(b, b_ndim)
return Blockwise(
CholeskySolve(lower=lower, check_finite=check_finite, b_ndim=b_ndim)
)(A, b)
class SolveTriangular(SolveBase): class SolveTriangular(SolveBase):
...@@ -254,6 +280,7 @@ class SolveTriangular(SolveBase): ...@@ -254,6 +280,7 @@ class SolveTriangular(SolveBase):
"unit_diagonal", "unit_diagonal",
"lower", "lower",
"check_finite", "check_finite",
"b_ndim",
) )
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs): def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
...@@ -291,6 +318,7 @@ def solve_triangular( ...@@ -291,6 +318,7 @@ def solve_triangular(
lower: bool = False, lower: bool = False,
unit_diagonal: bool = False, unit_diagonal: bool = False,
check_finite: bool = True, check_finite: bool = True,
b_ndim: Optional[int] = None,
) -> TensorVariable: ) -> TensorVariable:
"""Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix. """Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.
...@@ -314,12 +342,19 @@ def solve_triangular( ...@@ -314,12 +342,19 @@ def solve_triangular(
Whether to check that the input matrices contain only finite numbers. Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs. (crashes, non-termination) if the inputs do contain infinities or NaNs.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
""" """
return SolveTriangular( b_ndim = _default_b_ndim(b, b_ndim)
return Blockwise(
SolveTriangular(
lower=lower, lower=lower,
trans=trans, trans=trans,
unit_diagonal=unit_diagonal, unit_diagonal=unit_diagonal,
check_finite=check_finite, check_finite=check_finite,
b_ndim=b_ndim,
)
)(a, b) )(a, b)
...@@ -332,6 +367,7 @@ class Solve(SolveBase): ...@@ -332,6 +367,7 @@ class Solve(SolveBase):
"assume_a", "assume_a",
"lower", "lower",
"check_finite", "check_finite",
"b_ndim",
) )
def __init__(self, *, assume_a="gen", **kwargs): def __init__(self, *, assume_a="gen", **kwargs):
...@@ -352,7 +388,15 @@ class Solve(SolveBase): ...@@ -352,7 +388,15 @@ class Solve(SolveBase):
) )
def solve(a, b, *, assume_a="gen", lower=False, check_finite=True): def solve(
a,
b,
*,
assume_a="gen",
lower=False,
check_finite=True,
b_ndim: Optional[int] = None,
):
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix. """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
If the data matrix is known to be a particular type then supplying the If the data matrix is known to be a particular type then supplying the
...@@ -375,9 +419,9 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True): ...@@ -375,9 +419,9 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
Parameters Parameters
---------- ----------
a : (N, N) array_like a : (..., N, N) array_like
Square input data Square input data
b : (N, NRHS) array_like b : (..., N, NRHS) array_like
Input data for the right hand side. Input data for the right hand side.
lower : bool, optional lower : bool, optional
If True, only the data contained in the lower triangle of `a`. Default If True, only the data contained in the lower triangle of `a`. Default
...@@ -388,11 +432,18 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True): ...@@ -388,11 +432,18 @@ def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
(crashes, non-termination) if the inputs do contain infinities or NaNs. (crashes, non-termination) if the inputs do contain infinities or NaNs.
assume_a : str, optional assume_a : str, optional
Valid entries are explained above. Valid entries are explained above.
b_ndim : int
Whether the core case of b is a vector (1) or matrix (2).
This will influence how batched dimensions are interpreted.
""" """
return Solve( b_ndim = _default_b_ndim(b, b_ndim)
return Blockwise(
Solve(
lower=lower, lower=lower,
check_finite=check_finite, check_finite=check_finite,
assume_a=assume_a, assume_a=assume_a,
b_ndim=b_ndim,
)
)(a, b) )(a, b)
......
...@@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc): ...@@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc):
], ],
) )
def test_Solve(A, x, lower, exc): def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower=lower)(A, x) g = slinalg.Solve(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list): if isinstance(g, list):
g_fg = FunctionGraph(outputs=g) g_fg = FunctionGraph(outputs=g)
...@@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc): ...@@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc):
], ],
) )
def test_SolveTriangular(A, x, lower, exc): def test_SolveTriangular(A, x, lower, exc):
g = slinalg.SolveTriangular(lower=lower)(A, x) g = slinalg.SolveTriangular(lower=lower, b_ndim=1)(A, x)
if isinstance(g, list): if isinstance(g, list):
g_fg = FunctionGraph(outputs=g) g_fg = FunctionGraph(outputs=g)
......
...@@ -9,11 +9,12 @@ from pytensor import function ...@@ -9,11 +9,12 @@ from pytensor import function
from pytensor import tensor as at from pytensor import tensor as at
from pytensor.compile import get_default_mode from pytensor.compile import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
from pytensor.tensor.type import dmatrix, matrix, vector from pytensor.tensor.type import dmatrix, matrix, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.test_rop import break_op from tests.test_rop import break_op
...@@ -23,7 +24,7 @@ def test_rop_lop(): ...@@ -23,7 +24,7 @@ def test_rop_lop():
mx = matrix("mx") mx = matrix("mx")
mv = matrix("mv") mv = matrix("mv")
v = vector("v") v = vector("v")
y = matrix_inverse(mx).sum(axis=0) y = MatrixInverse()(mx).sum(axis=0)
yv = pytensor.gradient.Rop(y, mx, mv) yv = pytensor.gradient.Rop(y, mx, mv)
rop_f = function([mx, mv], yv) rop_f = function([mx, mv], yv)
...@@ -83,13 +84,11 @@ def test_transinv_to_invtrans(): ...@@ -83,13 +84,11 @@ def test_transinv_to_invtrans():
def test_generic_solve_to_solve_triangular(): def test_generic_solve_to_solve_triangular():
cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False)
A = matrix("A") A = matrix("A")
x = matrix("x") x = matrix("x")
L = cholesky_lower(A) L = cholesky(A, lower=True)
U = cholesky_upper(A) U = cholesky(A, lower=False)
b1 = solve(L, x) b1 = solve(L, x)
b2 = solve(U, x) b2 = solve(U, x)
f = pytensor.function([A, x], b1) f = pytensor.function([A, x], b1)
...@@ -130,15 +129,15 @@ def test_matrix_inverse_solve(): ...@@ -130,15 +129,15 @@ def test_matrix_inverse_solve():
b = dmatrix("b") b = dmatrix("b")
node = matrix_inverse(A).dot(b).owner node = matrix_inverse(A).dot(b).owner
[out] = inv_as_solve.transform(None, node) [out] = inv_as_solve.transform(None, node)
assert isinstance(out.owner.op, Solve) assert isinstance(out.owner.op, Blockwise) and isinstance(
out.owner.op.core_op, Solve
)
@pytest.mark.parametrize("tag", ("lower", "upper", None)) @pytest.mark.parametrize("tag", ("lower", "upper", None))
@pytest.mark.parametrize("cholesky_form", ("lower", "upper")) @pytest.mark.parametrize("cholesky_form", ("lower", "upper"))
@pytest.mark.parametrize("product", ("lower", "upper", None)) @pytest.mark.parametrize("product", ("lower", "upper", None))
def test_cholesky_ldotlt(tag, cholesky_form, product): def test_cholesky_ldotlt(tag, cholesky_form, product):
cholesky = Cholesky(lower=(cholesky_form == "lower"))
transform_removes_chol = tag is not None and product == tag transform_removes_chol = tag is not None and product == tag
transform_transposes = transform_removes_chol and cholesky_form != tag transform_transposes = transform_removes_chol and cholesky_form != tag
...@@ -153,11 +152,9 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): ...@@ -153,11 +152,9 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
else: else:
M = A M = A
C = cholesky(M) C = cholesky(M, lower=(cholesky_form == "lower"))
f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt")) f = pytensor.function([A], C, mode=get_default_mode().including("cholesky_ldotlt"))
print(f.maker.fgraph.apply_nodes)
no_cholesky_in_graph = not any( no_cholesky_in_graph = not any(
isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes isinstance(node.op, Cholesky) for node in f.maker.fgraph.apply_nodes
) )
......
...@@ -24,6 +24,7 @@ def test_vectorize_blockwise(): ...@@ -24,6 +24,7 @@ def test_vectorize_blockwise():
assert isinstance(vect_node.op, Blockwise) and isinstance( assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, MatrixInverse vect_node.op.core_op, MatrixInverse
) )
assert vect_node.op.signature == ("(m,m)->(m,m)")
assert vect_node.inputs[0] is tns assert vect_node.inputs[0] is tns
# Useless blockwise # Useless blockwise
...@@ -253,6 +254,11 @@ class TestMatrixInverse(MatrixOpBlockwiseTester): ...@@ -253,6 +254,11 @@ class TestMatrixInverse(MatrixOpBlockwiseTester):
signature = "(m, m) -> (m, m)" signature = "(m, m) -> (m, m)"
class TestSolve(BlockwiseOpTester): class TestSolveVector(BlockwiseOpTester):
core_op = Solve(lower=True) core_op = Solve(lower=True, b_ndim=1)
signature = "(m, m),(m) -> (m)" signature = "(m, m),(m) -> (m)"
class TestSolveMatrix(BlockwiseOpTester):
core_op = Solve(lower=True, b_ndim=2)
signature = "(m, m),(m, n) -> (m, n)"
...@@ -181,7 +181,7 @@ class TestSolveBase(utt.InferShapeTester): ...@@ -181,7 +181,7 @@ class TestSolveBase(utt.InferShapeTester):
( (
matrix, matrix,
functools.partial(tensor, dtype="floatX", shape=(None,) * 3), functools.partial(tensor, dtype="floatX", shape=(None,) * 3),
"`b` must be a matrix or a vector.*", "`b` must have 2 dims.*",
), ),
], ],
) )
...@@ -190,20 +190,20 @@ class TestSolveBase(utt.InferShapeTester): ...@@ -190,20 +190,20 @@ class TestSolveBase(utt.InferShapeTester):
with pytest.raises(ValueError, match=error_message): with pytest.raises(ValueError, match=error_message):
A = A_func() A = A_func()
b = b_func() b = b_func()
SolveBase()(A, b) SolveBase(b_ndim=2)(A, b)
def test__repr__(self): def test__repr__(self):
np.random.default_rng(utt.fetch_seed()) np.random.default_rng(utt.fetch_seed())
A = matrix() A = matrix()
b = matrix() b = matrix()
y = SolveBase()(A, b) y = SolveBase(b_ndim=2)(A, b)
assert y.__repr__() == "SolveBase{lower=False, check_finite=True}.0" assert y.__repr__() == "SolveBase{lower=False, check_finite=True, b_ndim=2}.0"
class TestSolve(utt.InferShapeTester): class TestSolve(utt.InferShapeTester):
def test__init__(self): def test__init__(self):
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
Solve(assume_a="test") Solve(assume_a="test", b_ndim=2)
assert "is not a recognized matrix structure" in str(excinfo.value) assert "is not a recognized matrix structure" in str(excinfo.value)
@pytest.mark.parametrize("b_shape", [(5, 1), (5,)]) @pytest.mark.parametrize("b_shape", [(5, 1), (5,)])
...@@ -278,7 +278,7 @@ class TestSolve(utt.InferShapeTester): ...@@ -278,7 +278,7 @@ class TestSolve(utt.InferShapeTester):
if config.floatX == "float64": if config.floatX == "float64":
eps = 2e-8 eps = 2e-8
solve_op = Solve(assume_a=assume_a, lower=lower) solve_op = Solve(assume_a=assume_a, lower=lower, b_ndim=1 if n is None else 2)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
...@@ -349,19 +349,20 @@ class TestSolveTriangular(utt.InferShapeTester): ...@@ -349,19 +349,20 @@ class TestSolveTriangular(utt.InferShapeTester):
if config.floatX == "float64": if config.floatX == "float64":
eps = 2e-8 eps = 2e-8
solve_op = SolveTriangular(lower=lower) solve_op = SolveTriangular(lower=lower, b_ndim=1 if n is None else 2)
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps) utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
class TestCholeskySolve(utt.InferShapeTester): class TestCholeskySolve(utt.InferShapeTester):
def setup_method(self): def setup_method(self):
self.op_class = CholeskySolve self.op_class = CholeskySolve
self.op = CholeskySolve()
self.op_upper = CholeskySolve(lower=False)
super().setup_method() super().setup_method()
def test_repr(self): def test_repr(self):
assert repr(CholeskySolve()) == "CholeskySolve(lower=True,check_finite=True)" assert (
repr(CholeskySolve(lower=True, b_ndim=1))
== "CholeskySolve(lower=True,check_finite=True,b_ndim=1)"
)
def test_infer_shape(self): def test_infer_shape(self):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
...@@ -369,7 +370,7 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -369,7 +370,7 @@ class TestCholeskySolve(utt.InferShapeTester):
b = matrix() b = matrix()
self._compile_and_check( self._compile_and_check(
[A, b], # pytensor.function inputs [A, b], # pytensor.function inputs
[self.op(A, b)], # pytensor.function outputs [self.op_class(b_ndim=2)(A, b)], # pytensor.function outputs
# A must be square # A must be square
[ [
np.asarray(rng.random((5, 5)), dtype=config.floatX), np.asarray(rng.random((5, 5)), dtype=config.floatX),
...@@ -383,7 +384,7 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -383,7 +384,7 @@ class TestCholeskySolve(utt.InferShapeTester):
b = vector() b = vector()
self._compile_and_check( self._compile_and_check(
[A, b], # pytensor.function inputs [A, b], # pytensor.function inputs
[self.op(A, b)], # pytensor.function outputs [self.op_class(b_ndim=1)(A, b)], # pytensor.function outputs
# A must be square # A must be square
[ [
np.asarray(rng.random((5, 5)), dtype=config.floatX), np.asarray(rng.random((5, 5)), dtype=config.floatX),
...@@ -397,10 +398,10 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -397,10 +398,10 @@ class TestCholeskySolve(utt.InferShapeTester):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = matrix() A = matrix()
b = matrix() b = matrix()
y = self.op(A, b) y = self.op_class(lower=True, b_ndim=2)(A, b)
cho_solve_lower_func = pytensor.function([A, b], y) cho_solve_lower_func = pytensor.function([A, b], y)
y = self.op_upper(A, b) y = self.op_class(lower=False, b_ndim=2)(A, b)
cho_solve_upper_func = pytensor.function([A, b], y) cho_solve_upper_func = pytensor.function([A, b], y)
b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX) b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
...@@ -435,12 +436,13 @@ class TestCholeskySolve(utt.InferShapeTester): ...@@ -435,12 +436,13 @@ class TestCholeskySolve(utt.InferShapeTester):
A_val = np.eye(2) A_val = np.eye(2)
b_val = np.ones((2, 1)) b_val = np.ones((2, 1))
op = self.op_class(b_ndim=2)
# try all dtype combinations # try all dtype combinations
for A_dtype, b_dtype in itertools.product(dtypes, dtypes): for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
A = matrix(dtype=A_dtype) A = matrix(dtype=A_dtype)
b = matrix(dtype=b_dtype) b = matrix(dtype=b_dtype)
x = self.op(A, b) x = op(A, b)
fn = function([A, b], x) fn = function([A, b], x)
x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype)) x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论