提交 2f5739eb authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add test

上级 4ae31749
......@@ -1221,6 +1221,8 @@ class TestIncSubtensor1(unittest.TestCase):
# also tests set_subtensor
def setUp(self):
self.rng = numpy.random.RandomState(seed=utt.fetch_seed())
self.s = tensor.iscalar()
self.v = tensor.fvector()
self.m = tensor.dmatrix()
......@@ -1267,6 +1269,21 @@ class TestIncSubtensor1(unittest.TestCase):
self.assertRaises(TypeError,
lambda: inc_subtensor(self.v[self.adv1q], fmatrix()))
def test_matrix_idx(self):
idx = tensor.lmatrix()
a = self.m[idx]
a2 = inc_subtensor(a, a)
f = theano.function([self.m, idx], a2)
mval = self.rng.random_sample((4, 10))
idxval = numpy.array([[1, 2], [3, 2]])
a2val = f(mval, idxval)
utt.assert_allclose(a2val[0], mval[0])
utt.assert_allclose(a2val[1], mval[1] * 2)
utt.assert_allclose(a2val[2], mval[2] * 3)
utt.assert_allclose(a2val[3], mval[3] * 2)
inplace_increment_missing = SkipTest(
"inc_subtensor with advanced indexing not enabled. "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论