提交 6aa9d88e authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

optimization for prod added

上级 856aa0b6
...@@ -4130,39 +4130,40 @@ def local_elemwise_sub_zeros(node): ...@@ -4130,39 +4130,40 @@ def local_elemwise_sub_zeros(node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([T.Sum]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
def local_sum_div_dimshuffle(node): def local_sum_prod_div_dimshuffle(node):
""" """
sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b, sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
if dimension l of the DimShuffle is 'x'. if dimension l of the DimShuffle is 'x'
or
prod(a / dimshuffle{...}(b), axis=l) ->
prod(a, axis={...}) / b ** a.shape[l],
if dimension l of the DimShuffle is 'x'
""" """
# TODO: extend it to product, and quotient of products
# It does not make much sense now to extend it to the case where the # It does not make much sense now to extend it to the case where the
# dimshuffle is in the numerator, since elemwise inversion of the # dimshuffle is in the numerator, since elemwise inversion of the
# denominator would still be needed before the summation. # denominator would still be needed before the summation or production.
if isinstance(node.op, T.Sum): if isinstance(node.op, (T.Sum, T.elemwise.Prod)):
axis = node.op.axis axis = node.op.axis
if axis is None: if axis is None:
axis = list(range(node.inputs[0].ndim)) axis = list(range(node.inputs[0].ndim))
# print 'axis =', axis node_input = node.inputs[0]
thing_summed = node.inputs[0] if node_input.owner and node_input.owner.op == T.true_div:
if thing_summed.owner and thing_summed.owner.op == T.true_div: numerator, denominator = node_input.owner.inputs
numerator, denominator = thing_summed.owner.inputs
# Old, bugged logic, reproduced here only to warn users # Old, bugged logic, reproduced here only to warn users
if config.warn.sum_div_dimshuffle_bug: if (config.warn.sum_div_dimshuffle_bug and
if numerator.owner and isinstance(numerator.owner.op, isinstance(node.op, T.Sum) and
T.DimShuffle): numerator.owner and
isinstance(numerator.owner.op, T.DimShuffle)):
# Check compatibility
new_order = numerator.owner.op.new_order new_order = numerator.owner.op.new_order
compatible_dims = True
for ax in axis: for ax in axis:
if len(new_order) <= ax or new_order[ax] != 'x': if ax < len(new_order) and new_order[ax] == 'x':
compatible_dims = False
break
if compatible_dims:
_logger.warn('WARNING: Your current code is fine, but' _logger.warn('WARNING: Your current code is fine, but'
' Theano versions between ' ' Theano versions between '
'rev. 3bd9b789f5e8 (2010-06-16) and' 'rev. 3bd9b789f5e8 (2010-06-16) and'
...@@ -4171,37 +4172,49 @@ def local_sum_div_dimshuffle(node): ...@@ -4171,37 +4172,49 @@ def local_sum_div_dimshuffle(node):
'To disable this warning, set the Theano' 'To disable this warning, set the Theano'
' flag warn.sum_div_dimshuffle_bug to' ' flag warn.sum_div_dimshuffle_bug to'
' False.') ' False.')
break
if denominator.owner and isinstance(denominator.owner.op, if denominator.owner and isinstance(denominator.owner.op,
T.DimShuffle): T.DimShuffle):
thing_dimshuffled = denominator.owner.inputs[0] dimshuffle_input = denominator.owner.inputs[0]
new_order = denominator.owner.op.new_order dimshuffle_order = denominator.owner.op.new_order
# print 'new_order =', new_order
# check compatibility
compatible_dims = True
for ax in axis:
# print 'ax =', ax
# print 'len(new_order) =', len(new_order)
# print 'new_order[ax] =', new_order[ax]
if len(new_order) <= ax or new_order[ax] != 'x':
compatible_dims = False
break
if compatible_dims: compatible_dims = []
# print 'getting denom out' incompatible_dims = []
# Keep needed dimensions for new dimshuffle for ax in axis:
new_new_order = list(ax for i, ax in enumerate(new_order) if (ax < len(dimshuffle_order) and
if i not in axis or ax != 'x') dimshuffle_order[ax] == 'x'):
# print 'new_new_order =', new_new_order compatible_dims.append(ax)
# Remove useless rebroadcast axes
while len(new_new_order) > 0 and new_new_order[0] == 'x':
del new_new_order[0]
# print 'new_new_order =', new_new_order
if all(i == e for i, e in enumerate(new_new_order)):
new_denom = thing_dimshuffled
else: else:
if config.warn.sum_div_dimshuffle_bug: incompatible_dims.append(ax)
reordered_incompatible_dims = []
for ic_ax in incompatible_dims:
reordered_incompatible_dims.append(
ic_ax - sum(
[1 for c_ax in compatible_dims if c_ax < ic_ax]))
if len(compatible_dims) > 0:
optimized_dimshuffle_order = list(
ax for i, ax in enumerate(dimshuffle_order)
if (i not in axis) or (ax != 'x'))
# Removing leading 'x' (since it will be done automatically)
while (len(optimized_dimshuffle_order) > 0 and
optimized_dimshuffle_order[0] == 'x'):
del optimized_dimshuffle_order[0]
# if optimized_dimshuffle_order is sorted with
# not 'x', then dimshuffle is useless.
if all(i == e for i, e in
enumerate(optimized_dimshuffle_order)):
optimized_dimshuffle = dimshuffle_input
else:
optimized_dimshuffle = T.DimShuffle(
dimshuffle_input.type.broadcastable,
optimized_dimshuffle_order)(dimshuffle_input)
if (config.warn.sum_div_dimshuffle_bug and
isinstance(node.op, T.Sum)):
_logger.warn('WARNING: Your current code is fine,' _logger.warn('WARNING: Your current code is fine,'
' but Theano versions between ' ' but Theano versions between '
'rev. 3bd9b789f5e8 (2010-06-16) and' 'rev. 3bd9b789f5e8 (2010-06-16) and'
...@@ -4212,12 +4225,27 @@ def local_sum_div_dimshuffle(node): ...@@ -4212,12 +4225,27 @@ def local_sum_div_dimshuffle(node):
'warn.sum_div_dimshuffle_bug' 'warn.sum_div_dimshuffle_bug'
' to False.') ' to False.')
new_denom = T.DimShuffle( if isinstance(node.op, T.Sum):
thing_dimshuffled.type.broadcastable, op_on_compatible_dims = T.sum(
new_new_order)(thing_dimshuffled) numerator, axis=compatible_dims)
return [T.true_div(node.op(numerator), new_denom)] div_op = T.true_div(
# else: op_on_compatible_dims,
# print 'incompatible dims:', axis, new_order optimized_dimshuffle)
op_on_incompatible_dims = T.sum(
div_op,
axis=reordered_incompatible_dims)
elif isinstance(node.op, T.elemwise.Prod):
op_on_compatible_dims = T.prod(
numerator, axis=compatible_dims)
div_op = T.true_div(
op_on_compatible_dims,
(optimized_dimshuffle **
T.prod([numerator.shape[ax]
for ax in compatible_dims])))
op_on_incompatible_dims = T.prod(
div_op,
axis=reordered_incompatible_dims)
return [op_on_incompatible_dims]
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论