提交 f3fc3be2 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2338 from abergeron/set_to_inc

Convert SetSubtensor to IncSubtensor when possible
...@@ -29,7 +29,8 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, ...@@ -29,7 +29,8 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, make_constant, Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedSubtensor1) AdvancedSubtensor1,
advanced_inc_subtensor1)
from theano import scalar from theano import scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
...@@ -1345,10 +1346,10 @@ def local_subtensor_make_vector(node): ...@@ -1345,10 +1346,10 @@ def local_subtensor_make_vector(node):
replace all subtensor(make_vector) like: replace all subtensor(make_vector) like:
[a,b,c][0] -> a [a,b,c][0] -> a
[a,b,c][0:2] -> [a,b] [a,b,c][0:2] -> [a,b]
replace all AdvancedSubtensor1(make_vector) like: replace all AdvancedSubtensor1(make_vector) like:
[a,b,c][[0,2]] -> [a,c] [a,b,c][[0,2]] -> [a,c]
we can do this for constant indexes we can do this for constant indexes
""" """
x = node.inputs[0] x = node.inputs[0]
...@@ -1629,13 +1630,13 @@ def local_elemwise_alloc(node): ...@@ -1629,13 +1630,13 @@ def local_elemwise_alloc(node):
if len(node.outputs) > 1: if len(node.outputs) > 1:
# Ensure all outputs have the same broadcast pattern # Ensure all outputs have the same broadcast pattern
# This is a supposition that I'm not sure is always true. # This is a supposition that I'm not sure is always true.
assert all([o.type.broadcastable == assert all([o.type.broadcastable ==
node.outputs[0].type.broadcastable for o in node.outputs[0].type.broadcastable for o in
node.outputs[1:]]) node.outputs[1:]])
# The broadcast pattern of the ouptut must match the broadcast pattern of # The broadcast pattern of the ouptut must match the broadcast pattern of
# at least one of the inputs. # at least one of the inputs.
if not any([i.type.broadcastable == if not any([i.type.broadcastable ==
node.outputs[0].type.broadcastable for i in node.inputs]): node.outputs[0].type.broadcastable for i in node.inputs]):
return False return False
...@@ -1874,6 +1875,38 @@ def local_useless_inc_subtensor(node): ...@@ -1874,6 +1875,38 @@ def local_useless_inc_subtensor(node):
return [Subtensor(node.op.idx_list)(*node.inputs[1:])] return [Subtensor(node.op.idx_list)(*node.inputs[1:])]
@register_canonicalize
@gof.local_optimizer([AdvancedIncSubtensor1])
def local_set_to_inc_subtensor(node):
    """
    Replace a set_subtensor whose update value is the read-back slice
    plus something by the cheaper increment form:

    AdvancedIncSubtensor1(x, x[ilist] + other, ilist,
                          set_instead_of_inc=True) ->
    AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)

    Returns a one-element list with the replacement variable, or None
    when the pattern does not match.
    """
    if (isinstance(node.op, AdvancedIncSubtensor1) and
            node.op.set_instead_of_inc and
            node.inputs[1].owner and
            isinstance(node.inputs[1].owner.op, Elemwise) and
            isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)):
        addn = node.inputs[1].owner
        # Figure out which operand of the add is x[ilist]; the other
        # operand is the actual increment.
        if (addn.inputs[0].owner and
                isinstance(addn.inputs[0].owner.op, AdvancedSubtensor1)):
            subn = addn.inputs[0].owner
            other = addn.inputs[1]
        elif (addn.inputs[1].owner and
                isinstance(addn.inputs[1].owner.op, AdvancedSubtensor1)):
            subn = addn.inputs[1].owner
            other = addn.inputs[0]
        else:
            return
        # The subtensor must read the same tensor with the same index
        # list that the set_subtensor writes; otherwise the rewrite
        # would change the result.
        if (subn.inputs[1] != node.inputs[2] or
                subn.inputs[0] != node.inputs[0]):
            return
        return [advanced_inc_subtensor1(node.inputs[0], other,
                                        node.inputs[2])]
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor])
......
...@@ -2697,6 +2697,35 @@ def test_local_IncSubtensor_serialize(): ...@@ -2697,6 +2697,35 @@ def test_local_IncSubtensor_serialize():
for inp in a.inputs]) for inp in a.inputs])
def test_local_set_to_inc_subtensor():
    """Check that local_set_to_inc_subtensor rewrites
    set_subtensor(x[ilist], x[ilist] + c) into an inc_subtensor, and
    that both compiled graphs compute the same values."""
    v = theano.tensor.fmatrix()
    s = v[[2, 1]]
    g = s + 3
    r = theano.tensor.set_subtensor(s, g)
    moder = compile.get_default_mode().excluding('local_set_to_inc_subtensor')
    modet = compile.get_default_mode().including('local_set_to_inc_subtensor')
    f1 = theano.function([v], r, mode=moder)
    f2 = theano.function([v], r, mode=modet)
    advi1 = [n for n in f1.maker.fgraph.toposort()
             if isinstance(n.op, tensor.AdvancedIncSubtensor1)]
    advi2 = [n for n in f2.maker.fgraph.toposort()
             if isinstance(n.op, tensor.AdvancedIncSubtensor1)]
    # Guard against vacuously-true all() below: the op must actually
    # be present in both compiled graphs.
    assert advi1 and advi2
    # Without the optimization, only the set form is present.
    assert all(n.op.set_instead_of_inc for n in advi1)
    # With the optimization, every occurrence was converted to inc.
    assert all(not n.op.set_instead_of_inc for n in advi2)
    # Seeded for reproducibility.
    val = numpy.random.RandomState(42).randn(3, 2).astype('float32')
    r1 = f1(val)
    r2 = f2(val)
    utt.assert_allclose(r1, r2)
def test_local_subtensor_of_dot(): def test_local_subtensor_of_dot():
m1 = theano.tensor.matrix() m1 = theano.tensor.matrix()
m2 = theano.tensor.matrix() m2 = theano.tensor.matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论