提交 4c40efa2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Stabilize rewrites shouldn't target Dot22

上级 2aa4aad7
...@@ -23,7 +23,6 @@ from pytensor.tensor.basic import ( ...@@ -23,7 +23,6 @@ from pytensor.tensor.basic import (
diag, diag,
diagonal, diagonal,
) )
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
...@@ -103,12 +102,12 @@ def transinv_to_invtrans(fgraph, node): ...@@ -103,12 +102,12 @@ def transinv_to_invtrans(fgraph, node):
@register_stabilize @register_stabilize
@node_rewriter([Dot, Dot22]) @node_rewriter([Dot])
def inv_as_solve(fgraph, node): def inv_as_solve(fgraph, node):
""" """
This utilizes a boolean `symmetric` tag on the matrices. This utilizes a boolean `symmetric` tag on the matrices.
""" """
if isinstance(node.op, Dot | Dot22): if isinstance(node.op, Dot):
l, r = node.inputs l, r = node.inputs
if ( if (
l.owner l.owner
...@@ -280,7 +279,7 @@ def cholesky_ldotlt(fgraph, node): ...@@ -280,7 +279,7 @@ def cholesky_ldotlt(fgraph, node):
A.owner is not None A.owner is not None
and ( and (
( (
isinstance(A.owner.op, Dot | Dot22) isinstance(A.owner.op, Dot)
# This rewrite only applies to matrix Dot # This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2 and A.owner.inputs[0].type.ndim == 2
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论