提交 b5a0c60d authored 作者: James Bergstra's avatar James Bergstra

Added sum-related canonicalizations.

上级 71b70e33
......@@ -1117,6 +1117,60 @@ def local_sum_mul_by_scalar(node):
if thing_summed.owner and thing_summed.owner.op == T.neg:
return [T.neg(node.op(thing_summed.owner.inputs[0]))]
@register_canonicalize
@gof.local_optimizer([])
def local_sum_all_to_none(node):
"""Sum{0,1,...N} -> Sum{}"""
if isinstance(node.op, T.Sum):
# if all the axes are named, then use None as a shorthand
# this permits more merging
if node.op.axis is None:
return
if set(node.op.axis) == set(range(node.inputs[0].type.ndim)):
return [T.Sum(axis=None)(node.inputs[0])]
@register_canonicalize
@gof.local_optimizer([])
def local_sum_sum(node):
"""Sum(Sum()) -> Sum"""
if isinstance(node.op, T.Sum):
summed, = node.inputs
if len(summed.clients) == 1:
if summed.owner and isinstance(summed.owner.op, T.Sum):
if summed.owner.op.axis is None:
# special case of local_cut_useless_reduce
return [T.Sum(None)(summed.owner.inputs[0])]
if node.op.axis is None:
# we're summing up everything anyway so lets
# do it all at once
return [T.Sum(None)(summed.owner.inputs[0])]
# figure out which dimensions of the original input are preserved
alldims = range(summed.owner.inputs[0].type.ndim)
# trim out the dimensions that were removed by the first sum
alldims = [d for i,d in enumerate(alldims) if i in summed.owner.op.axis]
# trim out the dimensions removed by second sum
alldims = [d for i,d in enumerate(alldims) if i in node.op.axis]
# figure out an axis argument that combines the effect of both
newaxis = [i for i in xrange(summed.owner.inputs[0].type.ndim)
if i not in alldims]
combined_sum = T.Sum(newaxis)
return [combined_sum(summed.owner.inputs[0])]
@register_canonicalize
@gof.local_optimizer([])
def local_cut_useless_reduce(node):
"""Sum(a, axis=[]) -> a """
if isinstance(node.op, T.CAReduce):
summed, = node.inputs
# if reduce were doing anything, the output ndim would be reduced
if summed.type == node.outputs[0].type:
return [summed]
@gof.local_optimizer([T.mul])
def local_mul_to_neg(node):
if node.op == T.mul and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == -1.0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论