提交 6f34947f authored 作者: James Bergstra's avatar James Bergstra

Added specialize optimization local_sum_mul_by_scalar

上级 8edd2804
......@@ -880,6 +880,37 @@ def local_neg_to_mul(node):
return [T.mul(-1, node.inputs[0])]
register_canonicalize(local_neg_to_mul)
@register_specialize
@gof.local_optimizer([])
def local_sum_mul_by_scalar(node):
"""sum(scalar * smth) -> scalar * sum(smth)
"""
# TODO: if the the thing inside the Sum is a division,
# we should get at the numerator....
if isinstance(node.op, T.Sum):
thing_summed, = node.inputs
if thing_summed.owner and thing_summed.owner.op == T.mul:
terms = thing_summed.owner.inputs
scalars = [t.dimshuffle() for t in terms if numpy.all(t.type.broadcastable)]
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
if scalars:
if len(scalars) > 1:
if len(non_scalars) > 1:
return [T.mul(T.mul(*scalars), node.op(T.mul(*non_scalars)))]
elif len(non_scalars) == 1:
return [T.mul(T.mul(*scalars), node.op(non_scalars[0]))]
else:
return [T.mul(*scalars)]
else:
if len(non_scalars) > 1:
return [T.mul(scalars[0], node.op(T.mul(*non_scalars)))]
elif len(non_scalars) == 1:
return [T.mul(scalars[0], node.op(non_scalars[0]))]
else:
return [scalars[0]]
if thing_summed.owner and thing_summed.owner.op == T.neg:
return [T.neg(node.op(thing_summed.owner.inputs[0]))]
@gof.local_optimizer([T.mul])
def local_mul_to_neg(node):
if node.op == T.mul and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == -1.0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论