提交 1715ae15 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the case where there is a single array to index by.

上级 29b9bcc2
...@@ -502,6 +502,7 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor): ...@@ -502,6 +502,7 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor):
x = x.reshape(nshp) x = x.reshape(nshp)
narrays = 0
transp = list(range(x.ndim)) transp = list(range(x.ndim))
p = 0 p = 0
for k, i in enumerate(list(nidx)): for k, i in enumerate(list(nidx)):
...@@ -512,6 +513,7 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor): ...@@ -512,6 +513,7 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor):
i = nidx.pop(k) i = nidx.pop(k)
nidx.insert(p, i) nidx.insert(p, i)
p += 1 p += 1
narrays += 1
x = x.transpose(*transp) x = x.transpose(*transp)
idx_ = ([slice(None)] * p + nidx[p:]) idx_ = ([slice(None)] * p + nidx[p:])
...@@ -535,7 +537,17 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor): ...@@ -535,7 +537,17 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor):
# finish up # finish up
out_flat_shp = take_idx.shape + x.shape[p:] out_flat_shp = take_idx.shape + x.shape[p:]
out[0] = out_flat.reshape(out_flat_shp) o = out_flat.reshape(out_flat_shp)
# If there was only one array we need to move the indexed
# dimension back
if narrays == 1:
k = transp[0]
ntransp = list(range(1, o.ndim))
ntransp.insert(k, 0)
o = o.transpose(*ntransp)
out[0] = o
class GpuAdvancedIncSubtensor1(Op): class GpuAdvancedIncSubtensor1(Op):
......
...@@ -55,6 +55,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -55,6 +55,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
inc_sub=tensor.IncSubtensor, inc_sub=tensor.IncSubtensor,
adv_sub1=tensor.AdvancedSubtensor1, adv_sub1=tensor.AdvancedSubtensor1,
adv_incsub1=tensor.AdvancedIncSubtensor1, adv_incsub1=tensor.AdvancedIncSubtensor1,
adv_sub=tensor.AdvancedSubtensor,
mode=None, mode=None,
dtype=theano.config.floatX, dtype=theano.config.floatX,
type=tensor.TensorType, type=tensor.TensorType,
...@@ -65,6 +66,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -65,6 +66,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self.inc_sub = inc_sub self.inc_sub = inc_sub
self.adv_sub1 = adv_sub1 self.adv_sub1 = adv_sub1
self.adv_incsub1 = adv_incsub1 self.adv_incsub1 = adv_incsub1
self.adv_sub = adv_sub
self.dimshuffle = dimshuffle self.dimshuffle = dimshuffle
if mode is None: if mode is None:
mode = theano.compile.mode.get_default_mode() mode = theano.compile.mode.get_default_mode()
...@@ -359,7 +361,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -359,7 +361,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
numpy_version = [int(v) for v in numpy.version.version.split('.')[0:2]] numpy_version = [int(v) for v in numpy.version.version.split('.')[0:2]]
if numpy_version >= [1, 9]: if numpy_version >= [1, 9]:
test_cases.append( test_cases.append(
(1, AdvancedSubtensor, AdvancedSubtensor, (1, AdvancedSubtensor, self.adv_sub,
numpy.index_exp[..., numpy.newaxis, [1, 2]])) numpy.index_exp[..., numpy.newaxis, [1, 2]]))
for length, op_type, op_type_opt, slice_ in test_cases: for length, op_type, op_type_opt, slice_ in test_cases:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论