提交 547e002b authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4678 from lamblin/fix_advincsubtensor1_reshape

Fix inc/set_subtensor when indexing with one non-vector
......@@ -1170,7 +1170,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
# This if is needed to prevent some useless warning about
# old code bug.
expanded_y = alloc(y, *[x.shape[i] for i in xrange(x.ndim)])
flattened_y = expanded_y.flatten(inner_x.ndim)
flattened_y = expanded_y.reshape(inner_x.shape)
else:
flattened_y = y
......
......@@ -1221,6 +1221,8 @@ class TestIncSubtensor1(unittest.TestCase):
# also tests set_subtensor
def setUp(self):
self.rng = numpy.random.RandomState(seed=utt.fetch_seed())
self.s = tensor.iscalar()
self.v = tensor.fvector()
self.m = tensor.dmatrix()
......@@ -1267,6 +1269,21 @@ class TestIncSubtensor1(unittest.TestCase):
self.assertRaises(TypeError,
lambda: inc_subtensor(self.v[self.adv1q], fmatrix()))
def test_matrix_idx(self):
idx = tensor.lmatrix()
a = self.m[idx]
a2 = inc_subtensor(a, a)
f = theano.function([self.m, idx], a2)
mval = self.rng.random_sample((4, 10))
idxval = numpy.array([[1, 2], [3, 2]])
a2val = f(mval, idxval)
utt.assert_allclose(a2val[0], mval[0])
utt.assert_allclose(a2val[1], mval[1] * 2)
utt.assert_allclose(a2val[2], mval[2] * 3)
utt.assert_allclose(a2val[3], mval[3] * 2)
inplace_increment_missing = SkipTest(
"inc_subtensor with advanced indexing not enabled. "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论