提交 8e5ad2e1 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a test for AdvancedSubtensor1 with noncontiguous index input.

上级 4ec4edaa
......@@ -529,6 +529,15 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
utt.verify_grad(lambda m: m[idx],
[data])
def test_noncontiguous_idx(self):
data = rand(4, 2, 3)
idx = [2, 2, 0, 0, 1, 1]
n = self.shared(data)
t = n[self.shared(numpy.asarray(idx))[::2]]
self.assertTrue(isinstance(t.owner.op, tensor.AdvancedSubtensor1))
val = self.eval_output_and_check(t, op_type=self.adv_sub1, length=2)
utt.assert_allclose(data[idx[::2]], val)
def test_err_invalid_list(self):
n = self.shared(numpy.asarray(5, dtype=self.dtype))
self.assertRaises(TypeError, n.__getitem__, [0, 0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论