Unverified 提交 197069d1 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Implement numba overload for POTRF, LAPACK cholesky routine (#578)

* Implement numba overload for POTRF, LAPACK cholesky routine
* Delete old numba_funcify_Cholesky
* Refactor tests to include supported keywords and datatypes
* Validate inputs and outputs of numba cholesky function
* Raise on complex inputs
* Change `cholesky` default for `check_finite` to `False`
* Remove redundant dtype checks from numba linalg dispatchers
* Add docstring to `numba_funcify_Cholesky` explaining why the overload is necessary.
上级 f737996d
......@@ -37,7 +37,7 @@ from pytensor.sparse import SparseTensorType
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Cholesky, Solve
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
......@@ -809,41 +809,6 @@ def numba_funcify_Softplus(op, node, **kwargs):
return softplus
@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
    """Build the Numba implementation for a ``Cholesky`` node.

    When ``op.lower`` is true the factorization is done with
    ``np.linalg.cholesky`` on the inputs passed through the caster built by
    ``int_to_float_fn``.  Otherwise a ``UserWarning`` is emitted and
    ``scipy.linalg.cholesky`` is called from inside a ``numba.objmode``
    block.  In both cases the result is cast to the dtype of the node's
    first output.
    """
    lower = op.lower
    out_dtype = node.outputs[0].type.numpy_dtype

    if not lower:
        # TODO: Use SciPy's BLAS/LAPACK Cython wrappers.
        warnings.warn(
            "Numba will use object mode to allow the "
            "`lower` argument to `scipy.linalg.cholesky`.",
            UserWarning,
        )

        # Numba needs the return type of the object-mode block up front.
        ret_sig = get_numba_type(node.outputs[0].type)

        @numba_njit
        def cholesky(a):
            with numba.objmode(ret=ret_sig):
                ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
            return ret

        return cholesky

    inputs_cast = int_to_float_fn(node.inputs, out_dtype)

    @numba_njit
    def cholesky(a):
        return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)

    return cholesky
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
......
......@@ -9,7 +9,7 @@ from scipy import linalg
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, SolveTriangular
_PTR = ctypes.POINTER
......@@ -25,6 +25,15 @@ _ptr_char = _PTR(_char)
_ptr_int = _PTR(_int)
@numba.core.extending.register_jitable
def _check_finite_matrix(a, func_name):
    """Raise ``np.linalg.LinAlgError`` if ``a`` holds a nan or inf entry.

    Every element of ``a`` is visited through ``np.nditer``; the first
    non-finite one triggers the error.  ``func_name`` is appended to the
    error message to identify the calling routine.
    """
    for element in np.nditer(a):
        is_finite = np.isfinite(element.item())
        if not is_finite:
            raise np.linalg.LinAlgError(
                "Non-numeric values (nan or inf) in input to " + func_name
            )
@intrinsic
def val_to_dptr(typingctx, data):
def impl(context, builder, signature, args):
......@@ -177,6 +186,22 @@ class _LAPACK:
return functype(lapack_ptr)
@classmethod
def numba_xpotrf(cls, dtype):
    """
    Return a ctypes callable for the LAPACK ``potrf`` routine matching ``dtype``.

    Called by scipy.linalg.cholesky
    """
    lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "potrf")

    # Argument types, in the order the routine receives them.
    arg_types = (
        _ptr_int,  # UPLO
        _ptr_int,  # N
        float_pointer,  # A
        _ptr_int,  # LDA
        _ptr_int,  # INFO
    )

    # The first CFUNCTYPE argument is the result type; None means no result.
    prototype = ctypes.CFUNCTYPE(None, *arg_types)
    return prototype(lapack_ptr)
def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False):
return linalg.solve_triangular(
......@@ -190,13 +215,7 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
_check_scipy_linalg_matrix(A, "solve_triangular")
_check_scipy_linalg_matrix(B, "solve_triangular")
dtype = A.dtype
if str(dtype).startswith("complex"):
raise ValueError(
"Complex inputs not currently supported by solve_triangular in Numba mode"
)
w_type = _get_underlying_float(dtype)
numba_trtrs = _LAPACK().numba_xtrtrs(dtype)
......@@ -249,8 +268,8 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
)
if B_is_1d:
return B_copy[..., 0]
return B_copy
return B_copy[..., 0], int_ptr_to_val(INFO)
return B_copy, int_ptr_to_val(INFO)
return impl
......@@ -262,19 +281,122 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by solve_triangular in Numba mode"
)
@numba_basic.numba_njit(inline="always")
def solve_triangular(a, b):
res = _solve_triangular(a, b, trans, lower, unit_diagonal)
if check_finite:
if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))):
raise ValueError(
"Non-numeric values (nan or inf) returned by solve_triangular"
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input A to solve_triangular"
)
if np.any(np.bitwise_or(np.isinf(b), np.isnan(b))):
raise np.linalg.LinAlgError(
"Non-numeric values (nan or inf) in input b to solve_triangular"
)
res, info = _solve_triangular(a, b, trans, lower, unit_diagonal)
if info != 0:
raise np.linalg.LinAlgError(
"Singular matrix in input A to solve_triangular"
)
return res
return solve_triangular
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
    """Forward to ``scipy.linalg.cholesky`` with the given keyword options.

    This plain-Python function is the target that ``@overload(_cholesky)``
    replaces when it is called from Numba-compiled code.
    """
    options = {
        "lower": lower,
        "overwrite_a": overwrite_a,
        "check_finite": check_finite,
    }
    return linalg.cholesky(a, **options)
@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
    """Numba overload of ``_cholesky`` that calls LAPACK ``potrf`` directly.

    Parameters
    ----------
    A
        Square matrix to factorize.
    lower
        Truthy to compute the lower factor (``UPLO = 'L'``), falsy for the
        upper factor (``UPLO = 'U'``).
    overwrite_a
        If true *and* ``A`` is Fortran-contiguous, the factorization is
        written into ``A`` itself; otherwise a Fortran-ordered copy is used.
    check_finite
        Accepted for signature compatibility with ``_cholesky``; it is not
        used here, so callers must do any finiteness check themselves.

    Returns
    -------
    The implementation returns a tuple ``(factor, info)``, where ``info``
    is the integer status code written by LAPACK into ``INFO``.

    Raises
    ------
    np.linalg.LinAlgError
        If the last two dimensions of ``A`` are not equal.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "cholesky")
    dtype = A.dtype
    w_type = _get_underlying_float(dtype)
    numba_potrf = _LAPACK().numba_xpotrf(dtype)

    def impl(A, lower=0, overwrite_a=False, check_finite=True):
        _N = np.int32(A.shape[-1])
        if A.shape[-2] != _N:
            raise np.linalg.LinAlgError("Last 2 dimensions of A must be square")

        UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
        N = val_to_int_ptr(_N)
        LDA = val_to_int_ptr(_N)
        INFO = val_to_int_ptr(0)

        # LAPACK reads the buffer in column-major order.  Handing it a buffer
        # that is not Fortran-contiguous would make it factorize the transpose
        # and write the factor into the wrong triangle, so A is only reused in
        # place when its layout is already suitable.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        numba_potrf(
            UPLO,
            N,
            A_copy.view(w_type).ctypes,
            LDA,
            INFO,
        )

        # POTRF does not reference the strictly opposite triangle, which still
        # holds the input values.  Zero it so that the result is a genuinely
        # triangular factor, as returned by scipy.linalg.cholesky.
        if lower:
            for j in range(1, _N):
                for i in range(j):
                    A_copy[i, j] = 0.0
        else:
            for j in range(_N):
                for i in range(j + 1, _N):
                    A_copy[i, j] = 0.0

        return A_copy, int_ptr_to_val(INFO)

    return impl
@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
    """
    Overload scipy.linalg.cholesky with a numba function.

    Note that np.linalg.cholesky is already implemented in numba, but it does
    not support additional keyword arguments. In particular, it has no
    equivalent of scipy's `overwrite_a` argument, which is why we choose to
    implement our own version.

    The returned function optionally checks the input for nan/inf, calls
    `_cholesky`, and then handles a non-zero LAPACK `info` code according to
    `op.on_error`: raise an error, or return an array filled with nan.
    """
    lower = op.lower
    overwrite_a = False
    check_finite = op.check_finite
    # Resolve the string comparison once, outside the compiled function.
    raise_on_failure = op.on_error == "raise"

    input_dtype = node.inputs[0].dtype
    if str(input_dtype).startswith("complex"):
        raise NotImplementedError(
            "Complex inputs not currently supported by cholesky in Numba mode"
        )

    @numba_basic.numba_njit(inline="always")
    def nb_cholesky(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to cholesky"
                )

        factor, info = _cholesky(a, lower, overwrite_a, check_finite)

        if raise_on_failure:
            if info < 0:
                raise ValueError(
                    'LAPACK reported an illegal value in input on entry to "POTRF."'
                )
            if info > 0:
                raise np.linalg.LinAlgError(
                    "Input to cholesky is not positive definite"
                )
        elif info != 0:
            # on_error is not "raise": signal failure with an all-nan result.
            factor = np.full_like(factor, np.nan)

        return factor

    return nb_cholesky
@numba_funcify.register(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype
......
......@@ -51,9 +51,10 @@ class Cholesky(Op):
__props__ = ("lower", "destructive", "on_error")
gufunc_signature = "(m,m)->(m,m)"
def __init__(self, *, lower=True, on_error="raise"):
def __init__(self, *, lower=True, check_finite=True, on_error="raise"):
self.lower = lower
self.destructive = False
self.check_finite = check_finite
if on_error not in ("raise", "nan"):
raise ValueError('on_error must be one of "raise" or ""nan"')
self.on_error = on_error
......@@ -70,7 +71,9 @@ class Cholesky(Op):
x = inputs[0]
z = outputs[0]
try:
z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
z[0] = scipy.linalg.cholesky(
x, lower=self.lower, check_finite=self.check_finite
).astype(x.dtype)
except scipy.linalg.LinAlgError:
if self.on_error == "raise":
raise
......@@ -129,8 +132,10 @@ class Cholesky(Op):
return [grad]
def cholesky(x, lower=True, on_error="raise"):
    """Apply a ``Blockwise``-wrapped ``Cholesky`` op to ``x``.

    ``lower`` and ``on_error`` are forwarded unchanged to the ``Cholesky``
    constructor.
    """
    core_op = Cholesky(lower=lower, on_error=on_error)
    return Blockwise(core_op)(x)
def cholesky(x, lower=True, on_error="raise", check_finite=False):
    """Apply a ``Blockwise``-wrapped ``Cholesky`` op to ``x``.

    ``lower``, ``on_error`` and ``check_finite`` are forwarded unchanged to
    the ``Cholesky`` constructor.  Note that ``check_finite`` defaults to
    ``False`` here, whereas ``Cholesky.__init__`` itself defaults it to
    ``True``.
    """
    core_op = Cholesky(lower=lower, on_error=on_error, check_finite=check_finite)
    return Blockwise(core_op)(x)
class SolveBase(Op):
......
......@@ -14,57 +14,6 @@ from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
rng = np.random.default_rng(42849)
@pytest.mark.parametrize(
    "x, lower, exc",
    [
        (
            set_test_value(
                pt.dmatrix(),
                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
            ),
            True,
            None,
        ),
        (
            set_test_value(
                pt.lmatrix(),
                (lambda x: x.T.dot(x))(
                    rng.integers(1, 10, size=(3, 3)).astype("int64")
                ),
            ),
            True,
            None,
        ),
        (
            set_test_value(
                pt.dmatrix(),
                (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
            ),
            False,
            UserWarning,
        ),
    ],
)
def test_Cholesky(x, lower, exc):
    """Compare the Numba and Python results of ``slinalg.Cholesky``.

    ``exc`` is the warning class expected while comparing, or ``None`` when
    no warning is expected.
    """
    g = slinalg.Cholesky(lower=lower)(x)
    outputs = g if isinstance(g, list) else [g]
    g_fg = FunctionGraph(outputs=outputs)

    # Only non-shared, non-constant inputs carry a test value to feed in.
    test_inputs = [
        i.tag.test_value
        for i in g_fg.inputs
        if not isinstance(i, (SharedVariable, Constant))
    ]

    cm = pytest.warns(exc) if exc is not None else contextlib.suppress()
    with cm:
        compare_numba_and_py(g_fg, test_inputs)
@pytest.mark.parametrize(
"A, x, lower, exc",
[
......
......@@ -6,7 +6,10 @@ import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import config
from pytensor.compile import SharedVariable
from pytensor.graph import Constant, FunctionGraph
from tests.link.numba.test_basic import compare_numba_and_py
from tests.tensor.test_extra_ops import set_test_value
numba = pytest.importorskip("numba")
......@@ -99,11 +102,62 @@ def test_solve_triangular_raises_on_nan_inf(value):
b = np.full((5, 1), value)
with pytest.raises(
ValueError, match=re.escape("Non-numeric values (nan or inf) returned ")
np.linalg.LinAlgError,
match=re.escape("Non-numeric values"),
):
f(A_tri, b)
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
def test_numba_Cholesky(lower):
    """Compare Numba and Python results of ``pt.linalg.cholesky``.

    The test value is built as ``m.T.dot(m)`` from a random 3x3 matrix.
    """
    x = set_test_value(
        pt.tensor(dtype=config.floatX, shape=(3, 3)),
        (lambda m: m.T.dot(m))(rng.random(size=(3, 3)).astype(config.floatX)),
    )

    g = pt.linalg.cholesky(x, lower=lower)
    g_fg = FunctionGraph(outputs=[g])

    # Only non-shared, non-constant inputs carry a test value to feed in.
    test_inputs = [
        i.tag.test_value
        for i in g_fg.inputs
        if not isinstance(i, (SharedVariable, Constant))
    ]
    compare_numba_and_py(g_fg, test_inputs)
def test_numba_Cholesky_raises_on_nan_input():
    """With ``check_finite=True`` a nan in the input must raise ``LinAlgError``."""
    bad_value = rng.random(size=(3, 3)).astype(config.floatX)
    bad_value[0, 0] = np.nan

    base = pt.tensor(dtype=config.floatX, shape=(3, 3))
    gram = base.T.dot(base)
    chol = pt.linalg.cholesky(gram, check_finite=True)

    # The product itself, not `base`, is declared as the function input, so
    # `bad_value` is supplied directly as the matrix handed to cholesky.
    fn = pytensor.function([gram], chol, mode="NUMBA")

    with pytest.raises(np.linalg.LinAlgError, match=r"Non-numeric values"):
        fn(bad_value)
@pytest.mark.parametrize("on_error", ["nan", "raise"])
def test_numba_Cholesky_raise_on(on_error):
    """Check both ``on_error`` modes on a random (non-symmetrized) 3x3 input.

    ``"raise"`` must produce a ``LinAlgError``; ``"nan"`` must return an
    array consisting only of nan.
    """
    test_value = rng.random(size=(3, 3)).astype(config.floatX)

    x = pt.tensor(dtype=config.floatX, shape=(3, 3))
    g = pt.linalg.cholesky(x, on_error=on_error)
    f = pytensor.function([x], g, mode="NUMBA")

    if on_error != "raise":
        assert np.all(np.isnan(f(test_value)))
        return

    with pytest.raises(
        np.linalg.LinAlgError, match=r"Input to cholesky is not positive definite"
    ):
        f(test_value)
def test_block_diag():
A = pt.matrix("A")
B = pt.matrix("B")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论