提交 fa84dc7e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update tests of local_subtensor_merge

- Simplify and move test case for the bug reported by Razvan - Add more general test of merging tensor[slice][index]
上级 e87df287
...@@ -2508,23 +2508,6 @@ class T_subtensor(unittest.TestCase): ...@@ -2508,23 +2508,6 @@ class T_subtensor(unittest.TestCase):
assert numpy.all(t_out == v_out) assert numpy.all(t_out == v_out)
assert numpy.all(t_out.shape == v_out.shape) assert numpy.all(t_out.shape == v_out.shape)
def test_merge_subtensor(self):
# Bug reported by Razvan
data = numpy.asarray(numpy.arange(8),
dtype = theano.config.floatX)
x = theano.tensor.vector()
y1 = x[2:][::-1]
y2 = x[:-2][::-1]
length = theano.tensor.minimum(y1.shape[0], y2.shape[0])
y1 = y1[:length]
y2 = y2[:length]
t = theano.shared(numpy.int64(0))
fun = theano.function([x], [y1[t], y2[t]])
val0, val1 = fun(data)
assert val0 == data[2:][::-1][0]
assert val1 == data[:-2][::-1][0]
def grad_list_(self, idxs, data): def grad_list_(self, idxs, data):
n = self.shared(data) n = self.shared(data)
......
...@@ -1742,7 +1742,8 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1742,7 +1742,8 @@ class test_local_subtensor_merge(unittest.TestCase):
def test_scalar5(self): def test_scalar5(self):
# var[int1:][:int2] # General case with two real slices
# var[b1:e1:s1][b2:e2:s2]
x = tensor.matrix('x') x = tensor.matrix('x')
b1 = tensor.iscalar('b1') b1 = tensor.iscalar('b1')
e1 = tensor.iscalar('e1') e1 = tensor.iscalar('e1')
...@@ -1777,6 +1778,59 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1777,6 +1778,59 @@ class test_local_subtensor_merge(unittest.TestCase):
for s2 in s2r: for s2 in s2r:
f(x_val, b1,e1,s1,b2,e2,s2) f(x_val, b1,e1,s1,b2,e2,s2)
def test_const4(self):
# Bug reported by Razvan
data = numpy.asarray(numpy.arange(8),
dtype = theano.config.floatX)
x = theano.tensor.vector('x')
y = x[7:1:-1]
t = theano.shared(numpy.int64(0))
fun = theano.function([x], y[t])
val = fun(data)
assert val == data[7:1:-1][0]
def test_scalar6(self):
# General case with one slice and one index
# var[b:e:s][i]
x = tensor.matrix('x')
b = tensor.iscalar('b')
e = tensor.iscalar('e')
s = tensor.iscalar('s')
i = tensor.iscalar('i')
f = function([x,b,e,s,i], x[b:e:s][i], mode=mode_opt)
#theano.printing.debugprint(f, print_type=True)
topo=f.maker.env.toposort()
#print [t for t in topo if isinstance(t.op, tensor.Subtensor)]
assert len([t for t in topo if isinstance(t.op, tensor.Subtensor)]) == 1
#print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
b_r = self.rng.permutation(range(-8,8))[:2]
e_r = self.rng.permutation(range(-8,8))[:2]
i_r = self.rng.permutation(range(-8,8))[:2]
s_r = self.rng.permutation([-7,-6,-5,-4,-3,-2,-1,1,2,3,4,5,6,7])[:2]
for x_s in self.x_shapes:
x_val = self.rng.uniform(size=x_s).astype(config.floatX)
for b_v in b_r:
for e_v in e_r:
for s_v in s_r:
for i_v in i_r:
# The index could be out of bounds
# In that case, an Exception should be raised,
# otherwise, we let DebugMode check f
try:
x_val[b_v:e_v:s_v][i_v]
except IndexError:
self.assertRaises(IndexError,
f, x_val, b_v, e_v, s_v, i_v)
else:
# Executed if the "try" clause did not
# raise an exception
f(x_val, b_v, e_v, s_v, i_v)
class Test_alloc_zero(unittest.TestCase): class Test_alloc_zero(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论