提交 eeb7e6ec authored 作者: Kelvin Xu's avatar Kelvin Xu

sum_alloc, sum_mul_by scalar, sum_all_to_non, sum_sum

上级 b75cf2e1
......@@ -3853,17 +3853,22 @@ register_canonicalize(local_neg_to_mul)
@register_specialize
@gof.local_optimizer([T.Sum])
def local_sum_mul_by_scalar(node):
@gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_prod_mul_by_scalar(node):
"""sum(scalar * smth) -> scalar * sum(smth)
sum(-smth) -> -sum(smth)
or
prod(scalar * smth) -> scalar * prod(smth)
prod(-smth) -> -prod(smth)
"""
# TODO: if the the thing inside the Sum is a division,
# we should get at the numerator....
if isinstance(node.op, T.Sum):
thing_summed, = node.inputs
if thing_summed.owner and thing_summed.owner.op == T.mul:
terms = thing_summed.owner.inputs
if isinstance(node.op, T.Sum) or isinstance(node.op, T.prod):
node_inps, = node.inputs
if node_inps.owner and node_inps.owner.op == T.mul:
terms = node_inps.owner.inputs
scalars = [t.dimshuffle() for t in terms if
numpy.all(t.type.broadcastable)]
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
......@@ -3885,8 +3890,8 @@ def local_sum_mul_by_scalar(node):
return [T.mul(scalars[0], node.op(non_scalars[0]))]
else:
return [scalars[0]]
if thing_summed.owner and thing_summed.owner.op == T.neg:
return [T.neg(node.op(thing_summed.owner.inputs[0]))]
if node_inps.owner and node_inps.owner.op == T.neg:
return [T.neg(node.op(node_inps.owner.inputs[0]))]
@register_specialize
......@@ -3993,64 +3998,68 @@ def local_sum_div_dimshuffle(node):
@register_canonicalize
@gof.local_optimizer([T.Sum])
def local_sum_all_to_none(node):
"""Sum{0,1,...N} -> Sum{}"""
if isinstance(node.op, T.Sum):
@gof.local_optimizer([T.Sum, T.elemwise.prod])
def local_sum_prod_all_to_none(node):
"""Sum{0,1,...N} -> Sum{} or
Prod{0,1,...N} -> Prod{}
"""
if isinstance(node.op, T.Sum) or isinstance(node.opt, T.elemwise.prod):
# if all the axes are named, then use None as a shorthand
# this permits more merging
if node.op.axis is None:
return
if set(node.op.axis) == set(range(node.inputs[0].type.ndim)):
return [T.Sum(axis=None, dtype=node.op.dtype)(node.inputs[0])]
return [node.op(axis=None, dtype=node.op.dtype)(node.inputs[0])]
@register_canonicalize
@gof.local_optimizer([T.Sum])
def local_sum_sum(node):
@gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_op_op(node):
"""
Prod(Prod()) -> Prod
or
Sum(Sum()) -> Sum
"""
if isinstance(node.op, T.Sum):
summed, = node.inputs
if isinstance(node.op, T.elemwise.Prod) or isinstance(node.op, T.Sum) :
node_inps = node.inputs
out_dtype = node.op.dtype
if len(summed.clients) == 1:
if (summed.owner and
isinstance(summed.owner.op, T.Sum)):
if summed.owner.op.axis is None:
# special case of local_cut_useless_reduce
return [T.Sum(None, dtype=out_dtype)(
summed.owner.inputs[0])]
if node.op.axis is None:
# we're summing up everything anyway so lets
# do it all at once
return [T.Sum(None, dtype=out_dtype)(
summed.owner.inputs[0])]
newaxis = list(tuple(summed.owner.op.axis))
# figure out which dimensions of the original input
# are preserved
# We manipulate the graph so this is done to make sure the opt
# doesn't affect other computations.
if len(node_inps.clients) == 1:
if (node_inps.owner and
(isinstance(node_inps.owner.op, T.elemwise.Prod) or
isinstance(node_inps.owner.op, T.Sum))
):
# check to see either the inner or outer prod is doing a
# product over all axis, in which case we can remove it
if node_inps.owner.op.axis is None or node.op.axis is None:
return [node.op(None, dtype=out_dtype)(
node_inps.owner.inputs[0])]
# figure out which axes were in the original sum
newaxis = list(tuple(node_inps.owner.op.axis))
for i in node.op.axis:
new_i = i
for ii in summed.owner.op.axis:
for ii in node_inps.owner.op.axis:
if new_i >= ii:
new_i += 1
assert new_i not in newaxis
newaxis.append(new_i)
assert len(newaxis) == len(list(summed.owner.op.axis) +
assert len(newaxis) == len(list(node_inps.owner.op.axis) +
list(node.op.axis))
# The old bugged logic. We keep it there to generate a warning
# when we generated bad code.
alldims = range(summed.owner.inputs[0].type.ndim)
alldims = range(node_inps.owner.inputs[0].type.ndim)
alldims = [d for i, d in enumerate(alldims) if i
in summed.owner.op.axis]
in node_inps.owner.op.axis]
alldims = [d for i, d in enumerate(alldims)
if i in node.op.axis]
newaxis_old = [i for i in
xrange(summed.owner.inputs[0].type.ndim)
xrange(node_inps.owner.inputs[0].type.ndim)
if i not in alldims]
if (theano.config.warn.sum_sum_bug and
......@@ -4069,8 +4078,9 @@ def local_sum_sum(node):
"been fixed) set the theano flag "
"`warn.sum_sum_bug` to False.")
combined_sum = T.Sum(newaxis, dtype=out_dtype)
return [combined_sum(summed.owner.inputs[0])]
combined = node.op(newaxis, dtype=out_dtype)
return [combined(node_inps.owner.inputs[0])]
ALL_REDUCE = [T.elemwise.CAReduce, T.elemwise.All, T.elemwise.Any,
T.elemwise.Sum, T.elemwise.Prod,
......@@ -4212,7 +4222,7 @@ def local_reduce_broadcastable(node):
@register_specialize
@gof.local_optimizer([T.Sum])
@gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes)"""
if isinstance(node.op, T.Sum):
......@@ -4244,6 +4254,43 @@ def local_sum_alloc(node):
except NotScalarConstantError:
pass
# I guess in this opt it might make sense to make a general local_opt_alloc?
# and in the code check if it is a prod or a sum and do the corresponding multiplication
# or exponentiation?
@register_specialize
@gof.local_optimizer([T.elemwise.Prod])
def local_prod_alloc(node):
""" prod(alloc(constant,shapes...)) => constant**prod(shapes)"""
if isinstance(node.op, T.elemwise.Prod):
prod_inps, = node.inputs
if prod_inps.owner and isinstance(summed.owner.op, T.Alloc):
input = prod_inps.owner.inputs[0]
shapes = prod_inps.owner.inputs[1:]
if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))):
try:
val = get_scalar_constant_value(input)
assert val.size == 1
val = val.reshape(1)[0] ** T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)]
except NotScalarConstantError:
pass
else:
try:
val = get_scalar_constant_value(input)
assert val.size == 1
val = val.reshape(1)[0]
to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis]
if to_prod:
val = val ** T.mul(*to_prod)
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])]
except NotScalarConstantError:
pass
@register_specialize
@gof.local_optimizer([T.neg])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论