Unverified 提交 197069d1 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Implement numba overload for POTRF, LAPACK cholesky routine (#578)

* Implement numba overload for POTRF, LAPACK cholesky routine * Delete old numba_funcify_Cholesky * Refactor tests to include supported keywords and datatypes * Validate inputs and outputs of numba cholesky function * Raise on complex inputs * Change `cholesky` default for `check_finite` to `False` * Remove redundant dtype checks from numba linalg dispatchers * Add docstring to `numba_funcify_Cholesky` explaining why the overload is necessary.
上级 f737996d
...@@ -37,7 +37,7 @@ from pytensor.sparse import SparseTensorType ...@@ -37,7 +37,7 @@ from pytensor.sparse import SparseTensorType
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Cholesky, Solve from pytensor.tensor.slinalg import Solve
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -809,41 +809,6 @@ def numba_funcify_Softplus(op, node, **kwargs): ...@@ -809,41 +809,6 @@ def numba_funcify_Softplus(op, node, **kwargs):
return softplus return softplus
@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
lower = op.lower
out_dtype = node.outputs[0].type.numpy_dtype
if lower:
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_njit
def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
else:
# TODO: Use SciPy's BLAS/LAPACK Cython wrappers.
warnings.warn(
(
"Numba will use object mode to allow the "
"`lower` argument to `scipy.linalg.cholesky`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba_njit
def cholesky(a):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
return ret
return cholesky
@numba_funcify.register(Solve) @numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs): def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a assume_a = op.assume_a
......
...@@ -9,7 +9,7 @@ from scipy import linalg ...@@ -9,7 +9,7 @@ from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular
_PTR = ctypes.POINTER _PTR = ctypes.POINTER
...@@ -25,6 +25,15 @@ _ptr_char = _PTR(_char) ...@@ -25,6 +25,15 @@ _ptr_char = _PTR(_char)
_ptr_int = _PTR(_int) _ptr_int = _PTR(_int)
@numba.core.extending.register_jitable
def _check_finite_matrix(a, func_name):
for v in np.nditer(a):
if not np.isfinite(v.item()):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input to " + func_name
)
@intrinsic @intrinsic
def val_to_dptr(typingctx, data): def val_to_dptr(typingctx, data):
def impl(context, builder, signature, args): def impl(context, builder, signature, args):
...@@ -177,6 +186,22 @@ class _LAPACK: ...@@ -177,6 +186,22 @@ class _LAPACK:
return functype(lapack_ptr) return functype(lapack_ptr)
@classmethod
def numba_xpotrf(cls, dtype):
"""
Called by scipy.linalg.cholesky
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # UPLO,
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # INFO
)
return functype(lapack_ptr)
def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False): def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False):
return linalg.solve_triangular( return linalg.solve_triangular(
...@@ -190,13 +215,7 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): ...@@ -190,13 +215,7 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
_check_scipy_linalg_matrix(A, "solve_triangular") _check_scipy_linalg_matrix(A, "solve_triangular")
_check_scipy_linalg_matrix(B, "solve_triangular") _check_scipy_linalg_matrix(B, "solve_triangular")
dtype = A.dtype dtype = A.dtype
if str(dtype).startswith("complex"):
raise ValueError(
"Complex inputs not currently supported by solve_triangular in Numba mode"
)
w_type = _get_underlying_float(dtype) w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
...@@ -249,8 +268,8 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): ...@@ -249,8 +268,8 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
) )
if B_is_1d: if B_is_1d:
return B_copy[..., 0] return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy return B_copy, int_ptr_to_val(INFO)
return impl return impl
...@@ -262,19 +281,122 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -262,19 +281,122 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
unit_diagonal = op.unit_diagonal unit_diagonal = op.unit_diagonal
check_finite = op.check_finite check_finite = op.check_finite
dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by solve_triangular in Numba mode"
)
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def solve_triangular(a, b): def solve_triangular(a, b):
res = _solve_triangular(a, b, trans, lower, unit_diagonal)
if check_finite: if check_finite:
if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))): if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise ValueError( raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) returned by solve_triangular" "Non-numeric values (nan or inf) in input A to solve_triangular"
) )
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to solve_triangular"
)
res, info = _solve_triangular(a, b, trans, lower, unit_diagonal)
if info != 0:
raise np.linalg.LinAlgError(
"Singular matrix in input A to solve_triangular"
)
return res return res
return solve_triangular return solve_triangular
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
return linalg.cholesky(
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
)
@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(A, "cholesky")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_potrf = _LAPACK().numba_xpotrf(dtype)
def impl(A, lower=0, overwrite_a=False, check_finite=True):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)
if not overwrite_a:
A_copy = _copy_to_fortran_order(A)
else:
A_copy = A
numba_potrf(
UPLO,
N,
A_copy.view(w_type).ctypes,
LDA,
INFO,
)
return A_copy, int_ptr_to_val(INFO)
return impl
@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
"""
Overload scipy.linalg.cholesky with a numba function.
Note that np.linalg.cholesky is already implemented in numba, but it does not support additional keyword arguments.
In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
"""
lower = op.lower
overwrite_a = False
check_finite = op.check_finite
on_error = op.on_error
dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by cholesky in Numba mode"
)
@numba_basic.numba_njit(inline="always")
def nb_cholesky(a):
if check_finite:
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) found in input to cholesky"
)
res, info = _cholesky(a, lower, overwrite_a, check_finite)
if on_error == "raise":
if info > 0:
raise np.linalg.LinAlgError(
"Input to cholesky is not positive definite"
)
if info < 0:
raise ValueError(
'LAPACK reported an illegal value in input on entry to "POTRF."'
)
else:
if info != 0:
res = np.full_like(res, np.nan)
return res
return nb_cholesky
@numba_funcify.register(BlockDiagonal) @numba_funcify.register(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs): def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype dtype = node.outputs[0].dtype
......
...@@ -51,9 +51,10 @@ class Cholesky(Op): ...@@ -51,9 +51,10 @@ class Cholesky(Op):
__props__ = ("lower", "destructive", "on_error") __props__ = ("lower", "destructive", "on_error")
gufunc_signature = "(m,m)->(m,m)" gufunc_signature = "(m,m)->(m,m)"
def __init__(self, *, lower=True, on_error="raise"): def __init__(self, *, lower=True, check_finite=True, on_error="raise"):
self.lower = lower self.lower = lower
self.destructive = False self.destructive = False
self.check_finite = check_finite
if on_error not in ("raise", "nan"): if on_error not in ("raise", "nan"):
raise ValueError('on_error must be one of "raise" or ""nan"') raise ValueError('on_error must be one of "raise" or ""nan"')
self.on_error = on_error self.on_error = on_error
...@@ -70,7 +71,9 @@ class Cholesky(Op): ...@@ -70,7 +71,9 @@ class Cholesky(Op):
x = inputs[0] x = inputs[0]
z = outputs[0] z = outputs[0]
try: try:
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype) z[0] = scipy.linalg.cholesky(
x, lower=self.lower, check_finite=self.check_finite
).astype(x.dtype)
except scipy.linalg.LinAlgError: except scipy.linalg.LinAlgError:
if self.on_error == "raise": if self.on_error == "raise":
raise raise
...@@ -129,8 +132,10 @@ class Cholesky(Op): ...@@ -129,8 +132,10 @@ class Cholesky(Op):
return [grad] return [grad]
def cholesky(x, lower=True, on_error="raise"): def cholesky(x, lower=True, on_error="raise", check_finite=False):
return Blockwise(Cholesky(lower=lower, on_error=on_error))(x) return Blockwise(
Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
)(x)
class SolveBase(Op): class SolveBase(Op):
......
...@@ -14,57 +14,6 @@ from tests.link.numba.test_basic import compare_numba_and_py, set_test_value ...@@ -14,57 +14,6 @@ from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
@pytest.mark.parametrize(
"x, lower, exc",
[
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
True,
None,
),
(
set_test_value(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
True,
None,
),
(
set_test_value(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
False,
UserWarning,
),
],
)
def test_Cholesky(x, lower, exc):
g = slinalg.Cholesky(lower=lower)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"A, x, lower, exc", "A, x, lower, exc",
[ [
......
...@@ -6,7 +6,10 @@ import pytest ...@@ -6,7 +6,10 @@ import pytest
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, FunctionGraph
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
from tests.tensor.test_extra_ops import set_test_value
numba = pytest.importorskip("numba") numba = pytest.importorskip("numba")
...@@ -99,11 +102,62 @@ def test_solve_triangular_raises_on_nan_inf(value): ...@@ -99,11 +102,62 @@ def test_solve_triangular_raises_on_nan_inf(value):
b = np.full((5, 1), value) b = np.full((5, 1), value)
with pytest.raises( with pytest.raises(
ValueError, match=re.escape("Non-numeric values (nan or inf) returned ") np.linalg.LinAlgError,
match=re.escape("Non-numeric values"),
): ):
f(A_tri, b) f(A_tri, b)
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
def test_numba_Cholesky(lower):
x = set_test_value(
pt.tensor(dtype=config.floatX, shape=(3, 3)),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype(config.floatX)),
)
g = pt.linalg.cholesky(x, lower=lower)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
def test_numba_Cholesky_raises_on_nan_input():
test_value = rng.random(size=(3, 3)).astype(config.floatX)
test_value[0, 0] = np.nan
x = pt.tensor(dtype=config.floatX, shape=(3, 3))
x = x.T.dot(x)
g = pt.linalg.cholesky(x, check_finite=True)
f = pytensor.function([x], g, mode="NUMBA")
with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
f(test_value)
@pytest.mark.parametrize("on_error", ["nan", "raise"])
def test_numba_Cholesky_raise_on(on_error):
test_value = rng.random(size=(3, 3)).astype(config.floatX)
x = pt.tensor(dtype=config.floatX, shape=(3, 3))
g = pt.linalg.cholesky(x, on_error=on_error)
f = pytensor.function([x], g, mode="NUMBA")
if on_error == "raise":
with pytest.raises(
np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite"
):
f(test_value)
else:
assert np.all(np.isnan(f(test_value)))
def test_block_diag(): def test_block_diag():
A = pt.matrix("A") A = pt.matrix("A")
B = pt.matrix("B") B = pt.matrix("B")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论