提交 5b9c07ec authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove useless ALL_REDUCE list

上级 e88117e6
......@@ -42,13 +42,8 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import (
All,
Any,
Dot,
FixedOpCAReduce,
NonZeroDimsCAReduce,
Prod,
ProdWithoutZeros,
Sum,
_conj,
add,
......@@ -1618,22 +1613,9 @@ def local_op_of_op(fgraph, node):
return [combined(node_inps.owner.inputs[0])]
ALL_REDUCE = [
CAReduce,
All,
Any,
Sum,
Prod,
ProdWithoutZeros,
*CAReduce.__subclasses__(),
*FixedOpCAReduce.__subclasses__(),
*NonZeroDimsCAReduce.__subclasses__(),
]
@register_canonicalize
@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce
@node_rewriter(ALL_REDUCE)
@node_rewriter([CAReduce])
def local_reduce_join(fgraph, node):
"""
CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
......@@ -1703,7 +1685,7 @@ def local_reduce_join(fgraph, node):
@register_infer_shape
@register_canonicalize("fast_compile", "local_cut_useless_reduce")
@register_useless("local_cut_useless_reduce")
@node_rewriter(ALL_REDUCE)
@node_rewriter([CAReduce])
def local_useless_reduce(fgraph, node):
"""Sum(a, axis=[]) -> a"""
(summed,) = node.inputs
......@@ -1715,7 +1697,7 @@ def local_useless_reduce(fgraph, node):
@register_canonicalize
@register_uncanonicalize
@register_specialize
@node_rewriter(ALL_REDUCE)
@node_rewriter([CAReduce])
def local_reduce_broadcastable(fgraph, node):
"""Remove reduction over broadcastable dimensions."""
(reduced,) = node.inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论