提交 169d0be8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #345 from lamblin/fix_subtensor_merge

Fix subtensor merge
...@@ -1676,7 +1676,7 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -1676,7 +1676,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
# case it was in reverse we need to realize that we do not want # case it was in reverse we need to realize that we do not want
# the k-th element from sl.start but the k-th element from # the k-th element from sl.start but the k-th element from
# sl.stop backwards # sl.stop backwards
n_val = sl1.stop - sl1.start - 1 - sl2 * sl1.step n_val = sl1.stop - 1 - sl2 * sl1.step
# we need to pick either n_val or p_val and then follow same # we need to pick either n_val or p_val and then follow same
# steps as above for covering the index error cases # steps as above for covering the index error cases
val = T.switch(T.lt(reverse1, 0), n_val, p_val) val = T.switch(T.lt(reverse1, 0), n_val, p_val)
......
...@@ -2508,23 +2508,6 @@ class T_subtensor(unittest.TestCase): ...@@ -2508,23 +2508,6 @@ class T_subtensor(unittest.TestCase):
assert numpy.all(t_out == v_out) assert numpy.all(t_out == v_out)
assert numpy.all(t_out.shape == v_out.shape) assert numpy.all(t_out.shape == v_out.shape)
def test_merge_subtensor(self):
# Bug reported by Razvan
data = numpy.asarray(numpy.arange(8),
dtype = theano.config.floatX)
x = theano.tensor.vector()
y1 = x[2:][::-1]
y2 = x[:-2][::-1]
length = theano.tensor.minimum(y1.shape[0], y2.shape[0])
y1 = y1[:length]
y2 = y2[:length]
t = theano.shared(numpy.int64(0))
fun = theano.function([x], [y1[t], y2[t]])
val0, val1 = fun(data)
assert val0 == data[2:][::-1][0]
assert val1 == data[:-2][::-1][0]
def grad_list_(self, idxs, data): def grad_list_(self, idxs, data):
n = self.shared(data) n = self.shared(data)
......
...@@ -1742,7 +1742,8 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1742,7 +1742,8 @@ class test_local_subtensor_merge(unittest.TestCase):
def test_scalar5(self): def test_scalar5(self):
# var[int1:][:int2] # General case with two real slices
# var[b1:e1:s1][b2:e2:s2]
x = tensor.matrix('x') x = tensor.matrix('x')
b1 = tensor.iscalar('b1') b1 = tensor.iscalar('b1')
e1 = tensor.iscalar('e1') e1 = tensor.iscalar('e1')
...@@ -1777,6 +1778,66 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1777,6 +1778,66 @@ class test_local_subtensor_merge(unittest.TestCase):
for s2 in s2r: for s2 in s2r:
f(x_val, b1,e1,s1,b2,e2,s2) f(x_val, b1,e1,s1,b2,e2,s2)
def test_const4(self):
# Bug reported by Razvan
data = numpy.asarray(numpy.arange(8),
dtype = theano.config.floatX)
x = theano.tensor.vector('x')
y = x[7:1:-1]
t = theano.shared(numpy.int64(0))
fun = theano.function([x], y[t])
val = fun(data)
assert val == data[7:1:-1][0]
def test_scalar6(self):
# General case with one slice and one index
# var[b:e:s][i]
x = tensor.matrix('x')
b = tensor.iscalar('b')
e = tensor.iscalar('e')
s = tensor.iscalar('s')
i = tensor.iscalar('i')
f = function([x,b,e,s,i], x[b:e:s][i], mode=mode_opt)
#theano.printing.debugprint(f, print_type=True)
topo=f.maker.env.toposort()
#print [t for t in topo if isinstance(t.op, tensor.Subtensor)]
assert len([t for t in topo if isinstance(t.op, tensor.Subtensor)]) == 1
#print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
b_r = self.rng.permutation(range(-4,4))[:3]
e_r = self.rng.permutation(range(-4,4))[:3]
i_r = self.rng.permutation(range(-4,4))[:3]
s_r = self.rng.permutation([-3,-2,-1,1,2,3])[:3]
for x_s in self.x_shapes:
n_index_err = 0
n_ok = 0
x_val = self.rng.uniform(size=x_s).astype(config.floatX)
for b_v in b_r:
for e_v in e_r:
for s_v in s_r:
for i_v in i_r:
# The index could be out of bounds
# In that case, an Exception should be raised,
# otherwise, we let DebugMode check f
try:
x_val[b_v:e_v:s_v][i_v]
except IndexError:
n_index_err += 1
self.assertRaises(IndexError,
f, x_val, b_v, e_v, s_v, i_v)
else:
# Executed if the "try" clause did not raise
# any exception
n_ok += 1
f(x_val, b_v, e_v, s_v, i_v)
print 'shape: %s' % (x_s,)
print '%% OK: %f' % (float(n_ok) * 100 / (n_ok + n_index_err))
class Test_alloc_zero(unittest.TestCase): class Test_alloc_zero(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论