Unverified 提交 b2b7e287 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Fix bug with dummy output clients in `local_det_chol` rewrite (#393)

* check for dummy outputs in local_det_chol rewrite * add rewrite check to 2nd test case * fix test
上级 82aeefc7
...@@ -162,6 +162,8 @@ def local_det_chol(fgraph, node): ...@@ -162,6 +162,8 @@ def local_det_chol(fgraph, node):
if isinstance(node.op, Det): if isinstance(node.op, Det):
(x,) = node.inputs (x,) = node.inputs
for cl, xpos in fgraph.clients[x]: for cl, xpos in fgraph.clients[x]:
if cl == "output":
continue
if isinstance(cl.op, Cholesky): if isinstance(cl.op, Cholesky):
L = cl.outputs[0] L = cl.outputs[0]
return [prod(at.extract_diag(L) ** 2)] return [prod(at.extract_diag(L) ** 2)]
......
...@@ -11,7 +11,7 @@ from pytensor.compile import get_default_mode ...@@ -11,7 +11,7 @@ from pytensor.compile import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import _allclose from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse from pytensor.tensor.nlinalg import Det, MatrixInverse, matrix_inverse
from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.rewriting.linalg import inv_as_solve
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
from pytensor.tensor.type import dmatrix, matrix, vector from pytensor.tensor.type import dmatrix, matrix, vector
...@@ -202,3 +202,19 @@ def test_cholesky_ldotlt(tag, cholesky_form, product): ...@@ -202,3 +202,19 @@ def test_cholesky_ldotlt(tag, cholesky_form, product):
f(Av), f(Av),
) )
) )
def test_local_det_chol():
X = matrix("X")
L = at.linalg.cholesky(X)
det_X = at.linalg.det(X)
f = function([X], [L, det_X])
nodes = f.maker.fgraph.toposort()
assert not any(isinstance(node, Det) for node in nodes)
# This previously raised an error (issue #392)
f = function([X], [L, det_X, X])
nodes = f.maker.fgraph.toposort()
assert not any(isinstance(node, Det) for node in nodes)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论