提交 9a0c05a5 作者: Frederic

Make AdvancedSubtensor1 reuse preallocated output.

上级 8bb90e12
......@@ -4880,7 +4880,11 @@ class AdvancedSubtensor1(Op):
x, i = inp
out, = out_
# Copy always implied by numpy advanced indexing semantics.
out[0] = x[i]
if out[0] is not None and out[0].shape==(len(i),)+x.shape[1:]:
o = out[0]
else:
o = None
out[0] = x.take(i, axis=0, out=o)
def grad(self, inputs, grads):
gz, = grads
......
......@@ -1809,6 +1809,8 @@ class T_subtensor(unittest.TestCase):
self.inc_sub = inc_sub
self.adv_sub1 = adv_sub1
self.adv_incsub1 = adv_incsub1
if mode is None:
mode = theano.compile.mode.get_default_mode()
self.mode = mode
self.dtype = dtype
self.ignore_topo = ignore_topo
......@@ -2093,6 +2095,19 @@ class T_subtensor(unittest.TestCase):
self.assertTrue(val.ndim == data.ndim)
self.assertTrue(numpy.allclose(val, good), (val, good))
# Test reuse of output memory
if isinstance(self.adv_sub1,tensor.AdvancedSubtensor1):
op = self.adv_sub1()
# When idx is a TensorConstant.
if hasattr(idx, "data"):
idx = idx.data
test_out = [[None]]
op.perform(None, [data, idx],test_out)
out1 = test_out[0][0]
op.perform(None, [data, idx],test_out)
out2 = test_out[0][0]
assert out1 is out2
def test_err_invalid_list(self):
n = self.shared(numpy.asarray(5, dtype=self.dtype))
self.assertRaises(TypeError, n.__getitem__, [0,0])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论