提交 f62366bc authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Finish test

上级 3327aadd
...@@ -1669,11 +1669,18 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -1669,11 +1669,18 @@ class TestAdvancedSubtensor(unittest.TestCase):
def test_adv_grouped(self): def test_adv_grouped(self):
# Reported in https://github.com/Theano/Theano/issues/6152 # Reported in https://github.com/Theano/Theano/issues/6152
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
var = self.shared(rng.rand(3, 63, 4).astype(config.floatX)) var_v = rng.rand(3, 63, 4).astype(config.floatX)
idx1 = self.shared(rng.randint(0, 61, size=(5, 4)).astype('int32')) var = self.shared(var_v)
idx1_v = rng.randint(0, 61, size=(5, 4)).astype('int32')
idx1 = self.shared(idx1_v)
idx2 = tensor.arange(4) idx2 = tensor.arange(4)
out = var[:, idx1, idx2] out = var[:, idx1, idx2]
f = theano.function([], out, mode=self.mode) f = theano.function([], out, mode=self.mode)
out_v = f()
assert out_v.shape == (3, 5, 4)
out_np = var_v[:, idx1_v, np.arange(4)]
utt.assert_allclose(out_v, out_np)
def test_grad(self): def test_grad(self):
ones = np.ones((1, 3), dtype=self.dtype) ones = np.ones((1, 3), dtype=self.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论