提交 aff13140 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add an optimization to convert set_subtensor() of an add into inc_subtensor().

This did come up in an implementation of adagrad at some point and will be much faster on GPU.
上级 1bafa2d4
...@@ -29,7 +29,8 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, ...@@ -29,7 +29,8 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, make_constant, Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedSubtensor1) AdvancedSubtensor1,
advanced_inc_subtensor1)
from theano import scalar from theano import scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
...@@ -1874,6 +1875,34 @@ def local_useless_inc_subtensor(node): ...@@ -1874,6 +1875,34 @@ def local_useless_inc_subtensor(node):
return [Subtensor(node.op.idx_list)(*node.inputs[1:])] return [Subtensor(node.op.idx_list)(*node.inputs[1:])]
@register_canonicalize
@gof.local_optimizer([AdvancedIncSubtensor1])
def local_set_to_inc_subtensor(node):
if (isinstance(node.op, AdvancedIncSubtensor1) and
node.op.set_instead_of_inc == True and
node.inputs[1].owner and
isinstance(node.inputs[1].owner.op, Elemwise) and
isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)):
addn = node.inputs[1].owner
subn = None
other = None
if (addn.inputs[0].owner and
isinstance(addn.inputs[0].owner.op, AdvancedSubtensor1)):
subn = addn.inputs[0].owner
other = addn.inputs[1]
elif (addn.inputs[1].owner and
isinstance(addn.inputs[1].owner.op, AdvancedSubtensor1)):
subn = addn.inputs[1].owner
other = addn.inputs[0]
else:
return
if (subn.inputs[1] != node.inputs[2] or
subn.inputs[0] != node.inputs[0]):
return
return [advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])]
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论