Commit 1db72747 authored by Frédéric Bastien, committed via GitHub

Merge pull request #4355 from Saizheng/master

#2801: subtensor-incsubtensor
@@ -1895,6 +1895,33 @@ def local_track_shape_i(node):
        return [shape_feature.shape_of[replacement][node.op.i]]
@register_specialize
@register_canonicalize
@gof.local_optimizer([Subtensor])
def local_subtensor_inc_subtensor(node):
    """
    Simplify ``Subtensor(SetSubtensor(x, y, idx), idx)`` to ``y``.

    When a ``set_subtensor`` result is immediately re-indexed with the very
    same indices, the read just returns the value that was written, so the
    whole round-trip can be replaced by ``y`` (or by an ``alloc`` of ``y``
    when ``y`` was broadcasted into the indexed region).
    """
    if not isinstance(node.op, Subtensor):
        return
    inner = node.inputs[0]
    # Only fire on a SetSubtensor (IncSubtensor with set_instead_of_inc).
    if inner.owner is None or not isinstance(inner.owner.op, IncSubtensor):
        return
    inc_op = inner.owner.op
    if not inc_op.set_instead_of_inc:
        return
    # The two index expressions must match both symbolically (the scalar
    # index inputs) and structurally (the static idx_list patterns).
    same_index_inputs = inner.owner.inputs[2:] == node.inputs[1:]
    same_idx_pattern = tuple(inc_op.idx_list) == tuple(node.op.idx_list)
    if not (same_index_inputs and same_idx_pattern):
        return
    x = inner.owner.inputs[0]
    y = inner.owner.inputs[1]
    # Each non-slice entry in idx_list drops one dimension of x; slices keep
    # theirs.  If y already has the ndim of x[idx], it was not broadcasted.
    n_dropped = len(node.op.idx_list) - sum(
        isinstance(entry, slice) for entry in node.op.idx_list)
    if x.ndim - n_dropped == y.ndim:
        # x[idx] and y have the same ndim (and shape): return y directly.
        return [y]
    # y is broadcastable against x[idx]: materialize it to the full shape.
    sub = node.op(x, *inner.owner.inputs[2:])
    return [T.alloc(y, *sub.shape)]
@register_specialize
@register_canonicalize
@gof.local_optimizer([Subtensor])
...
@@ -1934,6 +1934,78 @@ def test_local_subtensor_remove_broadcastable_index():
    f2(xn)
def test_subtensor_inc_subtensor():
    """
    Test the ``local_subtensor_inc_subtensor`` optimization.

    ``set_subtensor(x, v)[idx]`` with identical indices should simplify to
    ``v`` (just a DeepCopyOp in the compiled graph), or to an ``Alloc`` of
    ``v`` when ``v`` is broadcasted, and must be left untouched when the
    read indices differ from the written ones.
    """
    # Build the mode once; it is identical for every sub-case below.
    mode = theano.compile.mode.get_default_mode().including(
        'local_subtensor_inc_subtensor')

    # basic test: single scalar index
    x = tensor.matrix('x')
    i = tensor.iscalar('i')
    v = tensor.vector('v')
    y = tensor.set_subtensor(x[i], v)
    z = y[i]
    f = theano.function([x, i, v], z, mode=mode)
    prog = f.maker.fgraph.toposort()
    assert len(prog) == 1
    assert isinstance(prog[0].op, DeepCopyOp)
    # basic test, numerical check
    x_ = numpy.random.uniform(size=(3, 4)).astype(config.floatX)
    v_ = numpy.random.uniform(size=(4,)).astype(config.floatX)
    i_ = 1
    assert numpy.array_equal(f(x_, i_, v_), v_)

    # complicated test: mix of scalar indices and start/stop/step slices
    x = tensor.tensor4('x')
    i1 = tensor.iscalar('i1')
    i2 = tensor.iscalar('i2')
    i3 = tensor.iscalar('i3')
    i4 = tensor.iscalar('i4')
    v = tensor.tensor3('v')
    y = tensor.set_subtensor(x[i1, :i2, i3:, ::i4], v)
    z = y[i1, :i2, i3:, ::i4]
    f = theano.function([x, i1, i2, i3, i4, v], z, mode=mode)
    prog = f.maker.fgraph.toposort()
    assert len(prog) == 1
    assert isinstance(prog[0].op, DeepCopyOp)
    # complicated test, numerical check
    x_ = numpy.random.uniform(size=(3, 4, 5, 6)).astype(config.floatX)
    v_ = numpy.random.uniform(size=(2, 2, 2)).astype(config.floatX)
    i1_, i2_, i3_, i4_ = 1, 2, 3, 4
    assert numpy.array_equal(f(x_, i1_, i2_, i3_, i4_, v_), v_)

    # case where the read indices differ (i2/i3 swapped): the optimization
    # must NOT apply, so both IncSubtensor and Subtensor remain in the graph.
    z = y[i1, :i3, i2:, ::i4]
    f = theano.function([x, i1, i2, i3, i4, v], z, mode=mode)
    prog = f.maker.fgraph.toposort()
    assert len(prog) != 1
    assert any(isinstance(x.op, tensor.IncSubtensor) for x in prog)
    assert any(isinstance(x.op, tensor.Subtensor) for x in prog)
    # case not using this optimization, numerical check.  Setting v_ into x_
    # first makes the compiled set_subtensor idempotent, so the expected
    # result is just plain numpy indexing of the mutated x_.
    x_ = numpy.random.uniform(size=(3, 4, 5, 6)).astype(config.floatX)
    v_ = numpy.random.uniform(size=(2, 2, 2)).astype(config.floatX)
    i1_, i2_, i3_, i4_ = 1, 2, 3, 4
    x_[i1_, :i2_, i3_:, ::i4_] = v_
    assert numpy.array_equal(f(x_, i1_, i2_, i3_, i4_, v_),
                             x_[i1_, :i3_, i2_:, ::i4_])

    # case when v is broadcastable: expect an Alloc of v in the graph.
    # (Fixed: the two iscalars were both mistakenly named 'i'.)
    x = tensor.matrix('x')
    i1 = tensor.iscalar('i1')
    i2 = tensor.iscalar('i2')
    v = tensor.vector('v')
    y = tensor.set_subtensor(x[:i1, :i2], v)
    z = y[:i1, :i2]
    f = theano.function([x, i1, i2, v], z, mode=mode)
    prog = f.maker.fgraph.toposort()
    assert any(isinstance(x.op, tensor.Alloc) for x in prog)
    # case when v is broadcastable, numerical check
    x_ = numpy.random.uniform(size=(3, 4)).astype(config.floatX)
    v_ = numpy.random.uniform(size=(2,)).astype(config.floatX)
    i1_, i2_ = 2, 2
    x_[:i1_, :i2_] = v_
    assert numpy.array_equal(f(x_, i1_, i2_, v_), x_[:i1_, :i2_])
class test_local_subtensor_make_vector(unittest.TestCase):
    def test_scalar_idx(self):
        x, y, z = tensor.lscalars('xyz')
...
@@ -6730,7 +6802,7 @@ if __name__ == '__main__':
    t.setUp()
    # t.test_perform()
    t.test_infer_shape()
test_subtensor_inc_subtensor()
""" """
# unittest.main() # unittest.main()
test_fusion().tes_memory_leak() test_fusion().tes_memory_leak()
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment