Unverified 提交 f48068af authored 作者: Aidan Costello's avatar Aidan Costello 提交者: GitHub

Use LAPACK functions for `cho_solve`, `lu_factor`, `solve_triangular` (#1605)

* Use lapack instead of `scipy_linalg.cho_solve` * Use lapack instead of `scipy_linalg.lu_factor` * Use lapack instead of `scipy_linalg.solve_triangular` * Add empty test for lu_factor * Tidy imports * remove ndim check
上级 ac6c4e04
......@@ -7,7 +7,7 @@ from typing import Literal, cast
import numpy as np
import scipy.linalg as scipy_linalg
from numpy.exceptions import ComplexWarning
from scipy.linalg import get_lapack_funcs
from scipy.linalg import LinAlgError, LinAlgWarning, get_lapack_funcs
import pytensor
from pytensor import ifelse
......@@ -384,15 +384,28 @@ class CholeskySolve(SolveBase):
return Apply(self, [A, b], [out])
def perform(self, node, inputs, output_storage):
C, b = inputs
rval = scipy_linalg.cho_solve(
(C, self.lower),
b,
check_finite=self.check_finite,
overwrite_b=self.overwrite_b,
)
c, b = inputs
(potrs,) = get_lapack_funcs(("potrs",), (c, b))
output_storage[0][0] = rval
if self.check_finite and not (np.isfinite(c).all() and np.isfinite(b).all()):
raise ValueError("array must not contain infs or NaNs")
if c.shape[0] != c.shape[1]:
raise ValueError("The factored matrix c is not square.")
if c.shape[1] != b.shape[0]:
raise ValueError(f"incompatible dimensions ({c.shape} and {b.shape})")
# Quick return for empty arrays
if b.size == 0:
output_storage[0][0] = np.empty_like(b, dtype=potrs.dtype)
return
x, info = potrs(c, b, lower=self.lower, overwrite_b=self.overwrite_b)
if info != 0:
raise ValueError(f"illegal value in {-info}th argument of internal potrs")
output_storage[0][0] = x
def L_op(self, *args, **kwargs):
# TODO: Base impl should work, let's try it
......@@ -696,9 +709,27 @@ class LUFactor(Op):
def perform(self, node, inputs, outputs):
A = inputs[0]
LU, p = scipy_linalg.lu_factor(
A, overwrite_a=self.overwrite_a, check_finite=self.check_finite
)
# Quick return for empty arrays
if A.size == 0:
outputs[0][0] = np.empty_like(A)
outputs[1][0] = np.array([], dtype=np.int32)
return
if self.check_finite and not np.isfinite(A).all():
raise ValueError("array must not contain infs or NaNs")
(getrf,) = get_lapack_funcs(("getrf",), (A,))
LU, p, info = getrf(A, overwrite_a=self.overwrite_a)
if info < 0:
raise ValueError(
f"illegal value in {-info}th argument of internal getrf (lu_factor)"
)
if info > 0:
warnings.warn(
f"Diagonal number {info} is exactly zero. Singular matrix.",
LinAlgWarning,
stacklevel=2,
)
outputs[0][0] = LU
outputs[1][0] = p
......@@ -865,15 +896,51 @@ class SolveTriangular(SolveBase):
def perform(self, node, inputs, outputs):
A, b = inputs
outputs[0][0] = scipy_linalg.solve_triangular(
A,
b,
lower=self.lower,
trans=0,
unit_diagonal=self.unit_diagonal,
check_finite=self.check_finite,
overwrite_b=self.overwrite_b,
)
if self.check_finite and not (np.isfinite(A).all() and np.isfinite(b).all()):
raise ValueError("array must not contain infs or NaNs")
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError("expected square matrix")
if A.shape[0] != b.shape[0]:
raise ValueError(f"shapes of a {A.shape} and b {b.shape} are incompatible")
(trtrs,) = get_lapack_funcs(("trtrs",), (A, b))
# Quick return for empty arrays
if b.size == 0:
outputs[0][0] = np.empty_like(b, dtype=trtrs.dtype)
return
if A.flags["F_CONTIGUOUS"]:
x, info = trtrs(
A,
b,
overwrite_b=self.overwrite_b,
lower=self.lower,
trans=0,
unitdiag=self.unit_diagonal,
)
else:
# transposed system is solved since trtrs expects Fortran ordering
x, info = trtrs(
A.T,
b,
overwrite_b=self.overwrite_b,
lower=not self.lower,
trans=1,
unitdiag=self.unit_diagonal,
)
if info > 0:
raise LinAlgError(
f"singular matrix: resolution failed at diagonal {info-1}"
)
elif info < 0:
raise ValueError(f"illegal value in {-info}-th argument of internal trtrs")
outputs[0][0] = x
def L_op(self, inputs, outputs, output_gradients):
res = super().L_op(inputs, outputs, output_gradients)
......
......@@ -513,6 +513,31 @@ class TestSolveTriangular(utt.InferShapeTester):
utt.verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)
def test_solve_triangular_empty(self):
rng = np.random.default_rng(utt.fetch_seed())
A = pt.tensor("A", shape=(5, 5))
b = pt.tensor("b", shape=(5, 0))
A_val = rng.random((5, 5)).astype(config.floatX)
b_empty = np.empty([5, 0], dtype=config.floatX)
A_func = functools.partial(self.A_func, lower=True, unit_diagonal=True)
x = solve_triangular(
A_func(A),
b,
lower=True,
trans=0,
unit_diagonal=True,
b_ndim=len((5, 0)),
)
f = function([A, b], x)
res = f(A_val, b_empty)
assert res.size == 0
assert res.dtype == config.floatX
class TestCholeskySolve(utt.InferShapeTester):
def setup_method(self):
......@@ -797,6 +822,18 @@ def test_lu_factor():
)
def test_lu_factor_empty():
A = matrix()
f = function([A], lu_factor(A))
A_empty = np.empty([0, 0], dtype=config.floatX)
LU, pt_p_idx = f(A_empty)
assert LU.size == 0
assert LU.dtype == config.floatX
assert pt_p_idx.size == 0
def test_cho_solve():
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
......@@ -814,6 +851,21 @@ def test_cho_solve():
)
def test_cho_solve_empty():
rng = np.random.default_rng(utt.fetch_seed())
A = matrix()
b = matrix()
y = cho_solve((A, True), b)
cho_solve_lower_func = function([A, b], y)
A_empty = np.tril(np.asarray(rng.random((5, 5)), dtype=config.floatX))
b_empty = np.empty([5, 0], dtype=config.floatX)
res = cho_solve_lower_func(A_empty, b_empty)
assert res.size == 0
assert res.dtype == config.floatX
def test_expm():
rng = np.random.default_rng(utt.fetch_seed())
A = rng.standard_normal((5, 5)).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论