Unverified 提交 96f753b0 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Remove incorrect `solve` usage in `psd_solve_with_chol` rewrite (#575)

* Use `solve_triangular` instead of in `psd_solve_with_chol` * Add unittest for `psd_solve_with_chol` * Specify `mode=FAST_RUN` in test * Relax `test_psd_solve_with_chol` `atol` and `rtol` for half-precision tests
上级 65967fe2
...@@ -215,8 +215,8 @@ def psd_solve_with_chol(fgraph, node): ...@@ -215,8 +215,8 @@ def psd_solve_with_chol(fgraph, node):
# N.B. this can be further reduced to a yet-unwritten cho_solve Op # N.B. this can be further reduced to a yet-unwritten cho_solve Op
# __if__ no other Op makes use of the L matrix during the # __if__ no other Op makes use of the L matrix during the
# stabilization # stabilization
Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2) Li_b = solve_triangular(L, b, lower=True, b_ndim=2)
x = solve(_T(L), Li_b, assume_a="sym", lower=False, b_ndim=2) x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2)
return [x] return [x]
......
...@@ -241,6 +241,33 @@ def test_local_det_chol(): ...@@ -241,6 +241,33 @@ def test_local_det_chol():
assert not any(isinstance(node, Det) for node in nodes) assert not any(isinstance(node, Det) for node in nodes)
def test_psd_solve_with_chol():
X = matrix("X")
X.tag.psd = True
X_inv = pt.linalg.solve(X, pt.identity_like(X))
f = function([X], X_inv, mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Solve) for node in nodes)
assert any(isinstance(node.op, Cholesky) for node in nodes)
assert any(isinstance(node.op, SolveTriangular) for node in nodes)
# Numeric test
rng = np.random.default_rng(sum(map(ord, "test_psd_solve_with_chol")))
L = rng.normal(size=(5, 5)).astype(config.floatX)
X_psd = L @ L.T
X_psd_inv = f(X_psd)
assert_allclose(
X_psd_inv,
np.linalg.inv(X_psd),
atol=1e-4 if config.floatX == "float32" else 1e-8,
rtol=1e-4 if config.floatX == "float32" else 1e-8,
)
class TestBatchedVectorBSolveToMatrixBSolve: class TestBatchedVectorBSolveToMatrixBSolve:
rewrite_name = "batched_vector_b_solve_to_matrix_b_solve" rewrite_name = "batched_vector_b_solve_to_matrix_b_solve"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论