提交 06e7afea authored 作者: Mateusz Sokół's avatar Mateusz Sokół 提交者: Thomas Wiecki

Add slogdet for Numba

上级 65826e7e
...@@ -18,6 +18,7 @@ from pytensor.tensor.nlinalg import ( ...@@ -18,6 +18,7 @@ from pytensor.tensor.nlinalg import (
MatrixInverse, MatrixInverse,
MatrixPinv, MatrixPinv,
QRFull, QRFull,
SLogDet,
) )
...@@ -58,6 +59,25 @@ def numba_funcify_Det(op, node, **kwargs): ...@@ -58,6 +59,25 @@ def numba_funcify_Det(op, node, **kwargs):
return det return det
@numba_funcify.register(SLogDet)
def numba_funcify_SLogDet(op, node, **kwargs):
    """Build a Numba-jitted implementation of the ``SLogDet`` op."""
    # Dtypes of the two outputs: sign (outputs[0]) and log-determinant (outputs[1]).
    sign_dtype = node.outputs[0].type.numpy_dtype
    logdet_dtype = node.outputs[1].type.numpy_dtype

    # Integer inputs are promoted to float before the LAPACK call.
    to_float = int_to_float_fn(node.inputs, sign_dtype)

    @numba_basic.numba_njit
    def slogdet(x):
        sign, logdet = np.linalg.slogdet(to_float(x))
        sign_cast = numba_basic.direct_cast(sign, sign_dtype)
        logdet_cast = numba_basic.direct_cast(logdet, logdet_dtype)
        return (sign_cast, logdet_cast)

    return slogdet
@numba_funcify.register(Eig) @numba_funcify.register(Eig)
def numba_funcify_Eig(op, node, **kwargs): def numba_funcify_Eig(op, node, **kwargs):
......
...@@ -231,6 +231,39 @@ class Det(Op): ...@@ -231,6 +231,39 @@ class Det(Op):
det = Det() det = Det()
class SLogDet(Op):
    """Sign and natural logarithm of the determinant of a square matrix.

    Wraps ``np.linalg.slogdet``: produces two scalar outputs, the sign of
    the determinant and the log of its absolute value, both cast to the
    input's dtype.
    """

    __props__ = ()

    def make_node(self, x):
        x = as_tensor_variable(x)
        # Only square 2-D inputs are meaningful for a determinant.
        assert x.ndim == 2
        # Both outputs are 0-d tensors sharing the input dtype.
        outputs = [scalar(dtype=x.dtype), scalar(dtype=x.dtype)]
        return Apply(self, [x], outputs)

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        sign_out, det_out = outputs
        try:
            results = np.linalg.slogdet(x)
            sign_out[0], det_out[0] = [r.astype(x.dtype) for r in results]
        except Exception:
            print("Failed to compute determinant", x)
            raise

    def infer_shape(self, fgraph, node, shapes):
        # Two scalar outputs, regardless of the input matrix's shape.
        return [(), ()]

    def __str__(self):
        return "SLogDet"
class Eig(Op): class Eig(Op):
""" """
Compute the eigenvalues and right eigenvectors of a square array. Compute the eigenvalues and right eigenvectors of a square array.
......
...@@ -179,6 +179,41 @@ def test_Det(x, exc): ...@@ -179,6 +179,41 @@ def test_Det(x, exc):
) )
@pytest.mark.parametrize(
    "x, exc",
    [
        (
            set_test_value(
                at.dmatrix(),
                # Gram matrix of a random float matrix (symmetric PSD).
                (lambda m: m.T.dot(m))(rng.random(size=(3, 3)).astype("float64")),
            ),
            None,
        ),
        (
            set_test_value(
                at.lmatrix(),
                # Gram matrix of a random integer matrix.
                (lambda m: m.T.dot(m))(rng.poisson(size=(3, 3)).astype("int64")),
            ),
            None,
        ),
    ],
)
def test_SLogDet(x, exc):
    """Check the Numba SLogDet implementation against the Python one."""
    g = nlinalg.SLogDet()(x)
    g_fg = FunctionGraph(outputs=g)

    if exc is None:
        cm = contextlib.suppress()
    else:
        cm = pytest.warns(exc)

    with cm:
        test_inputs = [
            inp.tag.test_value
            for inp in g_fg.inputs
            if not isinstance(inp, (SharedVariable, Constant))
        ]
        compare_numba_and_py(g_fg, test_inputs)
# We were seeing some weird results in CI where the following two almost
# sign-swapped results were being returned from Numba and Python, respectively.
# The issue might be related to https://github.com/numba/numba/issues/4519.
......
...@@ -24,6 +24,7 @@ from pytensor.tensor.nlinalg import ( ...@@ -24,6 +24,7 @@ from pytensor.tensor.nlinalg import (
norm, norm,
pinv, pinv,
qr, qr,
slogdet,
svd, svd,
tensorinv, tensorinv,
tensorsolve, tensorsolve,
...@@ -280,6 +281,26 @@ def test_det_shape(): ...@@ -280,6 +281,26 @@ def test_det_shape():
assert tuple(det_shape.data) == () assert tuple(det_shape.data) == ()
def test_slogdet():
    """`slogdet` should agree with `np.linalg.slogdet` on a random matrix."""
    rng = np.random.default_rng(utt.fetch_seed())
    m = rng.standard_normal((5, 5)).astype(config.floatX)

    x = matrix()
    fn = pytensor.function([x], slogdet(x))
    actual_sign, actual_logdet = fn(m)

    expected_sign, expected_logdet = np.linalg.slogdet(m)
    assert np.equal(expected_sign, actual_sign)
    assert np.allclose(expected_logdet, actual_logdet)
def test_slogdet_shape():
    """Both `slogdet` outputs must have a statically-known scalar shape."""
    x = matrix()
    for output in slogdet(x):
        out_shape = output.shape
        assert isinstance(out_shape, Constant)
        assert tuple(out_shape.data) == ()
def test_trace(): def test_trace():
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
x = matrix() x = matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论