提交 a6aff9ab authored 作者: Kelvin Xu's avatar Kelvin Xu 提交者: Kelvin Xu

share local alloc code

上级 eeb7e6ec
...@@ -3865,7 +3865,7 @@ def local_sum_prod_mul_by_scalar(node): ...@@ -3865,7 +3865,7 @@ def local_sum_prod_mul_by_scalar(node):
""" """
# TODO: if the the thing inside the Sum is a division, # TODO: if the the thing inside the Sum is a division,
# we should get at the numerator.... # we should get at the numerator....
if isinstance(node.op, T.Sum) or isinstance(node.op, T.prod): if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
node_inps, = node.inputs node_inps, = node.inputs
if node_inps.owner and node_inps.owner.op == T.mul: if node_inps.owner and node_inps.owner.op == T.mul:
terms = node_inps.owner.inputs terms = node_inps.owner.inputs
...@@ -3998,43 +3998,43 @@ def local_sum_div_dimshuffle(node): ...@@ -3998,43 +3998,43 @@ def local_sum_div_dimshuffle(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Sum, T.elemwise.prod]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_prod_all_to_none(node): def local_sum_prod_all_to_none(node):
"""Sum{0,1,...N} -> Sum{} or """Sum{0,1,...N} -> Sum{} or
Prod{0,1,...N} -> Prod{} Prod{0,1,...N} -> Prod{}
""" """
if isinstance(node.op, T.Sum) or isinstance(node.opt, T.elemwise.prod): if isinstance(node.op, T.Sum) or isinstance(node.opt, T.elemwise.Prod):
opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod
# if all the axes are named, then use None as a shorthand # if all the axes are named, then use None as a shorthand
# this permits more merging # this permits more merging
if node.op.axis is None: if node.op.axis is None:
return return
if set(node.op.axis) == set(range(node.inputs[0].type.ndim)): if set(node.op.axis) == set(range(node.inputs[0].type.ndim)):
return [node.op(axis=None, dtype=node.op.dtype)(node.inputs[0])] return [opt_type(axis=None, dtype=node.op.dtype)(node.inputs[0])]
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Sum, T.elemwise.Prod]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_op_op(node): def local_op_of_op(node):
""" """
Prod(Prod()) -> Prod Prod(Prod()) -> single Prod()
or or
Sum(Sum()) -> Sum Sum(Sum()) -> single Sum()
""" """
if isinstance(node.op, T.elemwise.Prod) or isinstance(node.op, T.Sum) : if isinstance(node.op, T.elemwise.Prod) or isinstance(node.op, T.Sum):
node_inps = node.inputs opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod
node_inps, = node.inputs
out_dtype = node.op.dtype out_dtype = node.op.dtype
# We manipulate the graph so this is done to make sure the opt # We manipulate the graph so this is done to make sure the opt
# doesn't affect other computations. # doesn't affect other computations.
if len(node_inps.clients) == 1: if len(node_inps.clients) == 1:
if (node_inps.owner and if (node_inps.owner and (isinstance(node_inps.owner.op, T.elemwise.Prod)
(isinstance(node_inps.owner.op, T.elemwise.Prod) or or isinstance(node_inps.owner.op, T.elemwise.Sum))):
isinstance(node_inps.owner.op, T.Sum))
):
# check to see either the inner or outer prod is doing a # check to see either the inner or outer prod is doing a
# product over all axis, in which case we can remove it # product over all axis, in which case we can remove it
if node_inps.owner.op.axis is None or node.op.axis is None: if node_inps.owner.op.axis is None or node.op.axis is None:
return [node.op(None, dtype=out_dtype)( return [opt_type(None, dtype=out_dtype)(
node_inps.owner.inputs[0])] node_inps.owner.inputs[0])]
# figure out which axes were in the original sum # figure out which axes were in the original sum
...@@ -4078,7 +4078,7 @@ def local_op_op(node): ...@@ -4078,7 +4078,7 @@ def local_op_op(node):
"been fixed) set the theano flag " "been fixed) set the theano flag "
"`warn.sum_sum_bug` to False.") "`warn.sum_sum_bug` to False.")
combined = node.op(newaxis, dtype=out_dtype) combined = opt_type(newaxis, dtype=out_dtype)
return [combined(node_inps.owner.inputs[0])] return [combined(node_inps.owner.inputs[0])]
...@@ -4223,57 +4223,28 @@ def local_reduce_broadcastable(node): ...@@ -4223,57 +4223,28 @@ def local_reduce_broadcastable(node):
@register_specialize @register_specialize
@gof.local_optimizer([T.Sum, T.elemwise.Prod]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_alloc(node): def local_opt_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes)""" """ sum(alloc(constant,shapes...)) => constant*prod(shapes)
if isinstance(node.op, T.Sum): or
summed, = node.inputs prod(alloc(constant,shapes...)) => constant**prod(shapes)
if summed.owner and isinstance(summed.owner.op, T.Alloc): """
input = summed.owner.inputs[0] if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
shapes = summed.owner.inputs[1:] node_inps, = node.inputs
if node_inps.owner and isinstance(node_inps.owner.op, T.Alloc):
input = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:]
if (node.op.axis is None or if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))): node.op.axis == tuple(range(input.ndim))):
try: try:
val = get_scalar_constant_value(input) val = get_scalar_constant_value(input)
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] * T.mul(*shapes) # check which type of op
if isinstance(node.op, T.Sum):
val = val.reshape(1)[0] * T.mul(*shapes)
else:
val = val.reshape(1)[0] ** T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)] return [T.cast(val, dtype=node.outputs[0].dtype)]
except NotScalarConstantError:
pass
else:
try:
val = get_scalar_constant_value(input)
assert val.size == 1
val = val.reshape(1)[0]
to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis]
if to_prod:
val *= T.mul(*to_prod)
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])]
except NotScalarConstantError:
pass
# I guess in this opt it might make sense to make a general local_opt_alloc?
# and in the code check if it is a prod or a sum and do the corresponding multiplication
# or exponentiation?
@register_specialize
@gof.local_optimizer([T.elemwise.Prod])
def local_prod_alloc(node):
""" prod(alloc(constant,shapes...)) => constant**prod(shapes)"""
if isinstance(node.op, T.elemwise.Prod):
prod_inps, = node.inputs
if prod_inps.owner and isinstance(summed.owner.op, T.Alloc):
input = prod_inps.owner.inputs[0]
shapes = prod_inps.owner.inputs[1:]
if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))):
try:
val = get_scalar_constant_value(input)
assert val.size == 1
val = val.reshape(1)[0] ** T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)]
except NotScalarConstantError: except NotScalarConstantError:
pass pass
else: else:
...@@ -4284,7 +4255,10 @@ def local_prod_alloc(node): ...@@ -4284,7 +4255,10 @@ def local_prod_alloc(node):
to_prod = [shapes[i] for i in xrange(len(shapes)) to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis] if i in node.op.axis]
if to_prod: if to_prod:
val = val ** T.mul(*to_prod) if isintance(node.op, T.Sum):
val *= T.mul(*to_prod)
else:
val = val ** T.mul(*to_prod)
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype), return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
*[shapes[i] for i in xrange(len(shapes)) *[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])] if i not in node.op.axis])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论