提交 127f1000 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

New optimization: sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis=l) / b

上级 bae2c6df
...@@ -1600,6 +1600,56 @@ def local_sum_mul_by_scalar(node): ...@@ -1600,6 +1600,56 @@ def local_sum_mul_by_scalar(node):
if thing_summed.owner and thing_summed.owner.op == T.neg: if thing_summed.owner and thing_summed.owner.op == T.neg:
return [T.neg(node.op(thing_summed.owner.inputs[0]))] return [T.neg(node.op(thing_summed.owner.inputs[0]))]
@register_canonicalize
@gof.local_optimizer([])
def local_sum_div_dimshuffle(node):
'''sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis=l) / b,
if dimension l of the DimShuffle is 'x'.'''
# TODO: extend it to product, and quotient of products
if isinstance(node.op, T.Sum):
axis = node.op.axis
#print 'axis =', axis
thing_summed = node.inputs[0]
dimshuffled = None
if thing_summed.owner and thing_summed.owner.op == T.true_div:
numerator, denominator = thing_summed.owner.inputs
if isinstance(numerator.owner.op, T.DimShuffle):
new_order = numerator.owner.op.new_order
#print 'new_order =', new_order
# check compatibility
compatible_dims = True
for ax in axis:
if len(new_order) <= ax or new_order[ax] != 'x':
compatible_dims = False
break
if compatible_dims:
#print 'getting num out'
new_num = numerator.owner.inputs[0]
return [T.true_div(new_num, node.op(denominator))]
#else:
# print 'incompatible dims:', axis, new_order
if isinstance(denominator.owner.op, T.DimShuffle):
new_order = denominator.owner.op.new_order
#print 'new_order =', new_order
# check compatibility
compatible_dims = True
for ax in axis:
#print 'ax =', ax
#print 'len(new_order) =', len(new_order)
#print 'new_order[ax] =', new_order[ax]
if len(new_order) <= ax or new_order[ax] != 'x':
compatible_dims = False
break
if compatible_dims:
#print 'getting denom out'
new_denom = denominator.owner.inputs[0]
return [T.true_div(node.op(numerator), new_denom)]
#else:
# print 'incompatible dims:', axis, new_order
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_sum_all_to_none(node): def local_sum_all_to_none(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论