Unverified 提交 4b6a4440 authored 作者: theorashid's avatar theorashid 提交者: GitHub

Fix bug in nlinalg.slogdet and expose it in linalg module (#807)

上级 086323fa
...@@ -246,7 +246,7 @@ class SLogDet(Op): ...@@ -246,7 +246,7 @@ class SLogDet(Op):
(x,) = inputs (x,) = inputs
(sign, det) = outputs (sign, det) = outputs
try: try:
sign[0], det[0] = (z.astype(x.dtype) for z in np.linalg.slogdet(x)) sign[0], det[0] = (np.array(z, dtype=x.dtype) for z in np.linalg.slogdet(x))
except Exception: except Exception:
print("Failed to compute determinant", x) print("Failed to compute determinant", x)
raise raise
...@@ -1186,6 +1186,7 @@ __all__ = [ ...@@ -1186,6 +1186,7 @@ __all__ = [
"lstsq", "lstsq",
"matrix_power", "matrix_power",
"norm", "norm",
"slogdet",
"tensorinv", "tensorinv",
"tensorsolve", "tensorsolve",
"kron", "kron",
......
...@@ -387,6 +387,11 @@ def test_slogdet(): ...@@ -387,6 +387,11 @@ def test_slogdet():
sign, det = np.linalg.slogdet(r) sign, det = np.linalg.slogdet(r)
assert np.equal(sign, f_sign) assert np.equal(sign, f_sign)
assert np.allclose(det, f_det) assert np.allclose(det, f_det)
# check numpy array types is returned
# see https://github.com/pymc-devs/pytensor/issues/799
sign, logdet = slogdet(x)
det = sign * pytensor.tensor.exp(logdet)
assert_array_almost_equal(det.eval({x: [[1]]}), np.array(1.0))
def test_trace(): def test_trace():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论