提交 ff91f22e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add some tests for silly numpy behaviour and make them pass.

上级 296fabec
......@@ -519,16 +519,24 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor):
p += 1
narrays += 1
else:
try:
i.__index__()
# We shift back the position of the array by the
# number of dimensions that are removed by
# indexing. If ap is bigger than 0 it means we
# have encountered at least one array.
if ap >= 0:
ap -= 1
except Exception:
pass
if narrays == 0:
try:
i.__index__()
# We shift back the position of the array by the
# number of dimensions that are removed by
# indexing. If ap is bigger than 0 it means we
# have encountered at least one array.
if ap >= 0:
ap -= 1
# If this index is before the first array then
# we will not move the array back to its
# position. Mark this by faking that there
# are more than two arrays. This is crazy
# numpy behaviour so blame them.
if narrays == 0:
narrays = 2
except Exception:
pass
x = x.transpose(*transp)
......@@ -556,11 +564,11 @@ class GpuAdvancedSubtensor(HideC, tensor.AdvancedSubtensor):
o = out_flat.reshape(out_flat_shp)
# If there was only one array we need to move the indexed
# dimension back
# dimension(s) back to the position of the array, which is
# stored in ap. Note that ap is invalid is narrays != 1.
if narrays == 1:
k = ap
ntransp = list(range(1, o.ndim))
ntransp.insert(k, 0)
ntransp = list(range(take_idx.ndim, o.ndim))
ntransp[ap:ap] = list(range(take_idx.ndim))
o = o.transpose(*ntransp)
out[0] = o
......
......@@ -1435,6 +1435,42 @@ class TestAdvancedSubtensor(unittest.TestCase):
rval = ft4v[0, :, ix2v, :]
utt.assert_allclose(rval, aval)
def test_adv_subtensor_w_none_and_matrix(self):
subt = self.ft4[:, None, :, self.ix2, :]
f = theano.function([self.ft4, self.ix2], subt, mode=self.mode)
ft4v = numpy.random.random((2, 3, 4, 5)).astype('float32')
ix2v = numpy.asarray([[0, 1], [1, 0]])
aval = f(ft4v, ix2v)
rval = ft4v[:, None, :, ix2v, :]
utt.assert_allclose(rval, aval)
def test_adv_subtensor_w_slice_and_matrix(self):
subt = self.ft4[:, 0:1, self.ix2, :]
f = theano.function([self.ft4, self.ix2], subt, mode=self.mode)
ft4v = numpy.random.random((2, 3, 4, 5)).astype('float32')
ix2v = numpy.asarray([[0, 1], [1, 0]])
aval = f(ft4v, ix2v)
rval = ft4v[:, 0:1, ix2v, :]
utt.assert_allclose(rval, aval)
def test_adv_subtensor_w_matrix_and_int(self):
subt = self.ft4[:, :, self.ix2, 0]
f = theano.function([self.ft4, self.ix2], subt, mode=self.mode)
ft4v = numpy.random.random((2, 3, 4, 5)).astype('float32')
ix2v = numpy.asarray([[0, 1], [1, 0]])
aval = f(ft4v, ix2v)
rval = ft4v[:, :, ix2v, 0]
utt.assert_allclose(rval, aval)
def test_adv_subtensor_w_matrix_and_none(self):
subt = self.ft4[:, :, self.ix2, None, :]
f = theano.function([self.ft4, self.ix2], subt, mode=self.mode)
ft4v = numpy.random.random((2, 3, 4, 5)).astype('float32')
ix2v = numpy.asarray([[0, 1], [1, 0]])
aval = f(ft4v, ix2v)
rval = ft4v[:, :, ix2v, None, :]
utt.assert_allclose(rval, aval)
def test_inc_adv_subtensor_w_2vec(self):
if inplace_increment is None:
raise inplace_increment_missing
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论