提交 f3face81 authored 作者: James Bergstra's avatar James Bergstra

ENH: tensor.opt.local_sum_broadcastable

上级 572bf565
......@@ -3136,6 +3136,21 @@ def local_cut_useless_reduce(node):
return [summed]
@register_canonicalize
@gof.local_optimizer([])
def local_sum_broadcastable(node):
"""Remove reduction over broadcastable dimensions"""
if isinstance(node.op, T.CAReduce) and node.op.axis is not None:
reduced, = node.inputs
axis = list(node.op.axis)
cuttable = [a for a in axis if reduced.broadcastable[a]]
if cuttable == axis:
# -- in this case we can remove the reduction completely
pattern = [p for p in range(reduced.ndim) if p not in cuttable]
rval = reduced.dimshuffle(*pattern)
return [rval]
@register_specialize
@gof.local_optimizer([])
def local_sum_alloc(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论