提交 9637468f authored 作者: Frederic's avatar Frederic

Added subtensor test with negative value.

上级 9a0c05a5
...@@ -1887,7 +1887,9 @@ class T_subtensor(unittest.TestCase): ...@@ -1887,7 +1887,9 @@ class T_subtensor(unittest.TestCase):
def test2_ok_range_finite(self): def test2_ok_range_finite(self):
n = self.shared(numpy.ones((3,4), dtype=self.dtype)*5) n = self.shared(numpy.ones((3,4), dtype=self.dtype)*5)
t = n[0:2,3] # Also check negative index
for idx in [(slice(0,2),3),((slice(0,2),-1)),(slice(0,2),-4)]:
t = n[idx]#l]#0:2,3]
self.assertTrue(isinstance(t.owner.op, Subtensor)) self.assertTrue(isinstance(t.owner.op, Subtensor))
f = inplace_func([], t, mode=self.mode) f = inplace_func([], t, mode=self.mode)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
...@@ -1896,7 +1898,7 @@ class T_subtensor(unittest.TestCase): ...@@ -1896,7 +1898,7 @@ class T_subtensor(unittest.TestCase):
assert isinstance(topo_[0].op, self.sub) assert isinstance(topo_[0].op, self.sub)
tval = f() tval = f()
self.assertTrue(tval.shape == (2,)) self.assertTrue(tval.shape == (2,))
self.assertTrue(tval[1] == 5.0) self.assertTrue(numpy.allclose(tval, n.get_value()[idx]))
def test1_err_invalid(self): def test1_err_invalid(self):
n = self.shared(numpy.ones(1, dtype=self.dtype)) n = self.shared(numpy.ones(1, dtype=self.dtype))
...@@ -1948,7 +1950,8 @@ class T_subtensor(unittest.TestCase): ...@@ -1948,7 +1950,8 @@ class T_subtensor(unittest.TestCase):
def test2_err_bounds0(self): def test2_err_bounds0(self):
n = self.shared(numpy.ones((2,3), dtype=self.dtype)*5) n = self.shared(numpy.ones((2,3), dtype=self.dtype)*5)
t = n[0,4] for idx in [(0,4),(0,-4)]:
t = n[idx]
self.assertTrue(isinstance(t.owner.op, Subtensor)) self.assertTrue(isinstance(t.owner.op, Subtensor))
# Silence expected warnings # Silence expected warnings
_logger = logging.getLogger('theano.gof.opt') _logger = logging.getLogger('theano.gof.opt')
...@@ -1962,6 +1965,7 @@ class T_subtensor(unittest.TestCase): ...@@ -1962,6 +1965,7 @@ class T_subtensor(unittest.TestCase):
pass pass
finally: finally:
_logger.setLevel(oldlevel) _logger.setLevel(oldlevel)
def test2_err_bounds1(self): def test2_err_bounds1(self):
n = self.shared((numpy.ones((2,3), dtype=self.dtype)*5)) n = self.shared((numpy.ones((2,3), dtype=self.dtype)*5))
t = n[4:5,2] t = n[4:5,2]
...@@ -2077,6 +2081,10 @@ class T_subtensor(unittest.TestCase): ...@@ -2077,6 +2081,10 @@ class T_subtensor(unittest.TestCase):
(numpy.random.rand(4,5), [2,3]), (numpy.random.rand(4,5), [2,3]),
(numpy.random.rand(4,2,3), [0,3]), (numpy.random.rand(4,2,3), [0,3]),
(numpy.random.rand(4,2,3), [3,3,1,1,2,2,0,0]), (numpy.random.rand(4,2,3), [3,3,1,1,2,2,0,0]),
(numpy.random.rand(4,2,3), [3,3,1,1,2,2,0,0,-1,-2,-3,-4]),
# Test 4 dims as gpu code use another algo in that case
# This new algo is not as much optimized for that case.
(numpy.random.rand(4,4,2,3), [3,3,1,1,2,2,0,0,-1,-2,-3,-4]),
# Test with TensorConstant index. # Test with TensorConstant index.
(numpy.random.rand(4,2,3), constant([3,3,1,1,2,2,0,0])), (numpy.random.rand(4,2,3), constant([3,3,1,1,2,2,0,0])),
]: ]:
...@@ -2119,16 +2127,18 @@ class T_subtensor(unittest.TestCase): ...@@ -2119,16 +2127,18 @@ class T_subtensor(unittest.TestCase):
def test_err_bound_list(self): def test_err_bound_list(self):
n = self.shared(numpy.ones((2,3),dtype=self.dtype)*5) n = self.shared(numpy.ones((2,3),dtype=self.dtype)*5)
t = n[[0,4]] l = lvector()
t = n[l]
# We test again AdvancedSubtensor1 as we transfer data to the cpu. # We test again AdvancedSubtensor1 as we transfer data to the cpu.
self.assertTrue(isinstance(t.owner.op, theano.tensor.basic.AdvancedSubtensor1)) self.assertTrue(isinstance(t.owner.op, theano.tensor.basic.AdvancedSubtensor1))
f = function([], t, mode=self.mode) f = function([l], t, mode=self.mode)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
topo_ = [node for node in topo if not isinstance(node.op, self.ignore_topo)] topo_ = [node for node in topo if not isinstance(node.op, self.ignore_topo)]
assert len(topo_)==1 assert len(topo_)==1
self.assertTrue(isinstance(topo_[0].op, self.adv_sub1)) self.assertTrue(isinstance(topo_[0].op, self.adv_sub1))
self.assertRaises(IndexError, f) for shp in [[0,4],[0,-3], [-10]]:
self.assertRaises(IndexError, f, shp)
def test_adv_sub1_broadcast(self): def test_adv_sub1_broadcast(self):
ones = numpy.ones((1,3), dtype=self.dtype) ones = numpy.ones((1,3), dtype=self.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论