提交 4c40efa2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Stabilize rewrites shouldn't target Dot22

上级 2aa4aad7
......@@ -23,7 +23,6 @@ from pytensor.tensor.basic import (
diag,
diagonal,
)
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
......@@ -103,12 +102,12 @@ def transinv_to_invtrans(fgraph, node):
@register_stabilize
@node_rewriter([Dot, Dot22])
@node_rewriter([Dot])
def inv_as_solve(fgraph, node):
"""
This utilizes a boolean `symmetric` tag on the matrices.
"""
if isinstance(node.op, Dot | Dot22):
if isinstance(node.op, Dot):
l, r = node.inputs
if (
l.owner
......@@ -280,7 +279,7 @@ def cholesky_ldotlt(fgraph, node):
A.owner is not None
and (
(
isinstance(A.owner.op, Dot | Dot22)
isinstance(A.owner.op, Dot)
# This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论