提交 5b9c07ec authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove useless ALL_REDUCE list

上级 e88117e6
...@@ -42,13 +42,8 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise ...@@ -42,13 +42,8 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import ( from pytensor.tensor.math import (
All,
Any,
Dot, Dot,
FixedOpCAReduce,
NonZeroDimsCAReduce,
Prod, Prod,
ProdWithoutZeros,
Sum, Sum,
_conj, _conj,
add, add,
...@@ -1618,22 +1613,9 @@ def local_op_of_op(fgraph, node): ...@@ -1618,22 +1613,9 @@ def local_op_of_op(fgraph, node):
return [combined(node_inps.owner.inputs[0])] return [combined(node_inps.owner.inputs[0])]
ALL_REDUCE = [
CAReduce,
All,
Any,
Sum,
Prod,
ProdWithoutZeros,
*CAReduce.__subclasses__(),
*FixedOpCAReduce.__subclasses__(),
*NonZeroDimsCAReduce.__subclasses__(),
]
@register_canonicalize @register_canonicalize
@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce @register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce
@node_rewriter(ALL_REDUCE) @node_rewriter([CAReduce])
def local_reduce_join(fgraph, node): def local_reduce_join(fgraph, node):
""" """
CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b) CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
...@@ -1703,7 +1685,7 @@ def local_reduce_join(fgraph, node): ...@@ -1703,7 +1685,7 @@ def local_reduce_join(fgraph, node):
@register_infer_shape @register_infer_shape
@register_canonicalize("fast_compile", "local_cut_useless_reduce") @register_canonicalize("fast_compile", "local_cut_useless_reduce")
@register_useless("local_cut_useless_reduce") @register_useless("local_cut_useless_reduce")
@node_rewriter(ALL_REDUCE) @node_rewriter([CAReduce])
def local_useless_reduce(fgraph, node): def local_useless_reduce(fgraph, node):
"""Sum(a, axis=[]) -> a""" """Sum(a, axis=[]) -> a"""
(summed,) = node.inputs (summed,) = node.inputs
...@@ -1715,7 +1697,7 @@ def local_useless_reduce(fgraph, node): ...@@ -1715,7 +1697,7 @@ def local_useless_reduce(fgraph, node):
@register_canonicalize @register_canonicalize
@register_uncanonicalize @register_uncanonicalize
@register_specialize @register_specialize
@node_rewriter(ALL_REDUCE) @node_rewriter([CAReduce])
def local_reduce_broadcastable(fgraph, node): def local_reduce_broadcastable(fgraph, node):
"""Remove reduction over broadcastable dimensions.""" """Remove reduction over broadcastable dimensions."""
(reduced,) = node.inputs (reduced,) = node.inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论