提交 9238cc23 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Jesse Grabowski

Fix numba dispatch of Det and SlogDet returning non-arrays

上级 fc193d77
...@@ -52,7 +52,7 @@ def numba_funcify_Det(op, node, **kwargs): ...@@ -52,7 +52,7 @@ def numba_funcify_Det(op, node, **kwargs):
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def det(x): def det(x):
return numba_basic.direct_cast(np.linalg.det(inputs_cast(x)), out_dtype) return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype)
return det return det
...@@ -68,8 +68,8 @@ def numba_funcify_SLogDet(op, node, **kwargs): ...@@ -68,8 +68,8 @@ def numba_funcify_SLogDet(op, node, **kwargs):
def slogdet(x): def slogdet(x):
sign, det = np.linalg.slogdet(inputs_cast(x)) sign, det = np.linalg.slogdet(inputs_cast(x))
return ( return (
numba_basic.direct_cast(sign, out_dtype_1), np.array(sign).astype(out_dtype_1),
numba_basic.direct_cast(det, out_dtype_2), np.array(det).astype(out_dtype_2),
) )
return slogdet return slogdet
......
...@@ -11,68 +11,18 @@ from tests.link.numba.test_basic import compare_numba_and_py ...@@ -11,68 +11,18 @@ from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
@pytest.mark.parametrize( @pytest.mark.parametrize("dtype", ("float64", "int64"))
"x, exc", @pytest.mark.parametrize("op", (nlinalg.Det(), nlinalg.SLogDet()))
[ def test_Det_SLogDet(op, dtype):
( x = pt.matrix(dtype=dtype)
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
),
(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
),
None,
),
],
)
def test_Det(x, exc):
x, test_x = x
g = nlinalg.Det()(x)
cm = contextlib.suppress() if exc is None else pytest.warns(exc) rng = np.random.default_rng([50, sum(map(ord, dtype))])
with cm: x_ = rng.random(size=(3, 3)).astype(dtype)
compare_numba_and_py( test_x = x_.T.dot(x_)
[x],
g,
[test_x],
)
g = op(x)
@pytest.mark.parametrize( compare_numba_and_py([x], g, [test_x])
"x, exc",
[
(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
),
(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
),
None,
),
],
)
def test_SLogDet(x, exc):
x, test_x = x
g = nlinalg.SLogDet()(x)
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x],
g,
[test_x],
)
# We were seeing some weird results in CI where the following two almost # We were seeing some weird results in CI where the following two almost
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论