Unverified 提交 236e50d3 authored 作者: Aidan Costello's avatar Aidan Costello 提交者: GitHub

Use lapack func instead of `scipy.linalg.cholesky` (#1487)

* Use lapack func instead of `scipy.linalg.cholesky` * Now skips 2D checks in perform * Updated the default arguments for `check_finite` to false to match documentation * Add benchmark test case * Refactor out _cholesky helper, add empty test * Remove array and `potrf` copies * Update test_cholesky_raises_on_nan_input
上级 7886cf83
......@@ -37,7 +37,7 @@ class Cholesky(Op):
self,
*,
lower: bool = True,
check_finite: bool = True,
check_finite: bool = False,
on_error: Literal["raise", "nan"] = "raise",
overwrite_a: bool = False,
):
......@@ -67,29 +67,55 @@ class Cholesky(Op):
def perform(self, node, inputs, outputs):
[x] = inputs
[out] = outputs
try:
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
if self.overwrite_a and x.flags["C_CONTIGUOUS"]:
out[0] = scipy_linalg.cholesky(
x.T,
lower=not self.lower,
check_finite=self.check_finite,
overwrite_a=True,
).T
else:
out[0] = scipy_linalg.cholesky(
x,
lower=self.lower,
check_finite=self.check_finite,
overwrite_a=self.overwrite_a,
)
except scipy_linalg.LinAlgError:
if self.on_error == "raise":
raise
(potrf,) = scipy_linalg.get_lapack_funcs(("potrf",), (x,))
# Quick return for square empty array
if x.size == 0:
out[0] = np.empty_like(x, dtype=potrf.dtype)
return
if self.check_finite and not np.isfinite(x).all():
if self.on_error == "nan":
out[0] = np.full(x.shape, np.nan, dtype=potrf.dtype)
return
else:
raise ValueError("array must not contain infs or NaNs")
# Squareness check
if x.shape[0] != x.shape[1]:
raise ValueError(
"Input array is expected to be square but has " f"the shape: {x.shape}."
)
# Scipy cholesky only makes use of overwrite_a when it is F_CONTIGUOUS
# If we have a `C_CONTIGUOUS` array we transpose to benefit from it
c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"]
if c_contiguous_input:
x = x.T
lower = not self.lower
overwrite_a = True
else:
lower = self.lower
overwrite_a = self.overwrite_a
c, info = potrf(x, lower=lower, overwrite_a=overwrite_a, clean=True)
if info != 0:
if self.on_error == "nan":
out[0] = np.full(x.shape, np.nan, dtype=node.outputs[0].type.dtype)
elif info > 0:
raise scipy_linalg.LinAlgError(
f"{info}-th leading minor of the array is not positive definite"
)
elif info < 0:
raise ValueError(
f"LAPACK reported an illegal value in {-info}-th argument "
f'on entry to "POTRF".'
)
else:
# Transpose result if input was transposed
out[0] = c.T if c_contiguous_input else c
def L_op(self, inputs, outputs, gradients):
"""
......@@ -201,7 +227,9 @@ def cholesky(
"""
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x)
return Blockwise(
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
)(x)
class SolveBase(Op):
......
......@@ -465,7 +465,7 @@ def test_cholesky_raises_on_nan_input():
x = pt.tensor(dtype=floatX, shape=(3, 3))
x = x.T.dot(x)
g = pt.linalg.cholesky(x)
g = pt.linalg.cholesky(x, check_finite=True)
f = pytensor.function([x], g, mode="NUMBA")
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
......
......@@ -74,6 +74,26 @@ def test_cholesky():
check_upper_triangular(pd, ch_f)
def test_cholesky_performance(benchmark):
rng = np.random.default_rng(utt.fetch_seed())
r = rng.standard_normal((10, 10)).astype(config.floatX)
pd = np.dot(r, r.T)
x = matrix()
chol = cholesky(x)
ch_f = function([x], chol)
benchmark(ch_f, pd)
def test_cholesky_empty():
empty = np.empty([0, 0], dtype=config.floatX)
x = matrix()
chol = cholesky(x)
ch_f = function([x], chol)
ch = ch_f(empty)
assert ch.size == 0
assert ch.dtype == config.floatX
def test_cholesky_indef():
x = matrix()
mat = np.array([[1, 0.2], [0.2, -2]]).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论