提交 9238cc23 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Jesse Grabowski

Fix numba dispatch of Det and SlogDet returning non-arrays

上级 fc193d77
......@@ -52,7 +52,7 @@ def numba_funcify_Det(op, node, **kwargs):
@numba_basic.numba_njit(inline="always")
def det(x):
return numba_basic.direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)
return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype)
return det
......@@ -68,8 +68,8 @@ def numba_funcify_SLogDet(op, node, **kwargs):
def slogdet(x):
sign, det = np.linalg.slogdet(inputs_cast(x))
return (
numba_basic.direct_cast(sign, out_dtype_1),
numba_basic.direct_cast(det, out_dtype_2),
np.array(sign).astype(out_dtype_1),
np.array(det).astype(out_dtype_2),
)
return slogdet
......
......@@ -11,68 +11,18 @@ from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849)
@pytest.mark.parametrize(
"x, exc",
[
(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
),
(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
),
None,
),
],
)
def test_Det(x, exc):
x, test_x = x
g = nlinalg.Det()(x)
@pytest.mark.parametrize("dtype", ("float64", "int64"))
@pytest.mark.parametrize("op", (nlinalg.Det(), nlinalg.SLogDet()))
def test_Det_SLogDet(op, dtype):
x = pt.matrix(dtype=dtype)
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x],
g,
[test_x],
)
rng = np.random.default_rng([50, sum(map(ord, dtype))])
x_ = rng.random(size=(3, 3)).astype(dtype)
test_x = x_.T.dot(x_)
g = op(x)
@pytest.mark.parametrize(
"x, exc",
[
(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
),
(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
),
None,
),
],
)
def test_SLogDet(x, exc):
x, test_x = x
g = nlinalg.SLogDet()(x)
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x],
g,
[test_x],
)
compare_numba_and_py([x], g, [test_x])
# We were seeing some weird results in CI where the following two almost
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论