提交 f3fc3be2 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2338 from abergeron/set_to_inc

Convert SetSubtensor to IncSubtensor when possible
......@@ -29,7 +29,8 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1,
AdvancedIncSubtensor,
AdvancedSubtensor1)
AdvancedSubtensor1,
advanced_inc_subtensor1)
from theano import scalar
from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file
......@@ -1874,6 +1875,38 @@ def local_useless_inc_subtensor(node):
return [Subtensor(node.op.idx_list)(*node.inputs[1:])]
@register_canonicalize
@gof.local_optimizer([AdvancedIncSubtensor1])
def local_set_to_inc_subtensor(node):
"""
AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
"""
if (isinstance(node.op, AdvancedIncSubtensor1) and
node.op.set_instead_of_inc == True and
node.inputs[1].owner and
isinstance(node.inputs[1].owner.op, Elemwise) and
isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)):
addn = node.inputs[1].owner
subn = None
other = None
if (addn.inputs[0].owner and
isinstance(addn.inputs[0].owner.op, AdvancedSubtensor1)):
subn = addn.inputs[0].owner
other = addn.inputs[1]
elif (addn.inputs[1].owner and
isinstance(addn.inputs[1].owner.op, AdvancedSubtensor1)):
subn = addn.inputs[1].owner
other = addn.inputs[0]
else:
return
if (subn.inputs[1] != node.inputs[2] or
subn.inputs[0] != node.inputs[0]):
return
return [advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])]
@register_canonicalize
@register_specialize
@gof.local_optimizer([Subtensor])
......
......@@ -2697,6 +2697,35 @@ def test_local_IncSubtensor_serialize():
for inp in a.inputs])
def test_local_set_to_inc_subtensor():
v = theano.tensor.fmatrix()
s = v[[2, 1]]
g = s + 3
r = theano.tensor.set_subtensor(s, g)
moder = compile.get_default_mode().excluding('local_set_to_inc_subtensor')
modet = compile.get_default_mode().including('local_set_to_inc_subtensor')
f1 = theano.function([v], r, mode=moder)
f2 = theano.function([v], r, mode=modet)
advi1 = [n for n in f1.maker.fgraph.toposort()
if isinstance(n.op, tensor.AdvancedIncSubtensor1)]
advi2 = [n for n in f2.maker.fgraph.toposort()
if isinstance(n.op, tensor.AdvancedIncSubtensor1)]
# We only have SetSubtensor in f1
assert all(n.op.set_instead_of_inc for n in advi1)
# We don't have any SetSubtensor in f2
assert all(not n.op.set_instead_of_inc for n in advi2)
val = numpy.random.randn(3, 2).astype('float32')
r1 = f1(val)
r2 = f2(val)
utt.assert_allclose(r1, r2)
def test_local_subtensor_of_dot():
m1 = theano.tensor.matrix()
m2 = theano.tensor.matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论