提交 dd5a607f authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make the reduce tests pass and fix reference to old alias for shape

上级 451ba5fa
......@@ -1203,9 +1203,9 @@ def local_useless_alloc(node):
@register_specialize
@register_canonicalize
@gof.local_optimizer([T._shape])
@gof.local_optimizer([T.shape])
def local_shape_to_shape_i(node):
if node.op == T._shape:
if node.op == T.shape:
# This optimization needs ShapeOpt and fgraph.shape_feature
if not hasattr(node.fgraph, 'shape_feature'):
return
......@@ -3270,9 +3270,12 @@ def local_sum_sum(node):
combined_sum = T.Sum(newaxis, dtype=out_dtype)
return [combined_sum(summed.owner.inputs[0])]
ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any,
T.elemwise.Sum, T.elemwise.Prod,
T.elemwise.ProdWithoutZeros]
@register_canonicalize
@gof.local_optimizer([T.CAReduce])
@gof.local_optimizer(ALL_REDUCE)
def local_cut_useless_reduce(node):
"""Sum(a, axis=[]) -> a """
if isinstance(node.op, T.CAReduce):
......@@ -3288,7 +3291,7 @@ def local_cut_useless_reduce(node):
#
#@register_canonicalize
@register_specialize
@gof.local_optimizer([T.CAReduce])
@gof.local_optimizer(ALL_REDUCE)
def local_reduce_broadcastable(node):
"""Remove reduction over broadcastable dimensions"""
if isinstance(node.op, T.CAReduce):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论