提交 780813c7 authored 作者: lamblin's avatar lamblin

Merge pull request #578 from pascanur/fix_subtensor_problem

making the subtensor merge optimization work
...@@ -1803,7 +1803,9 @@ def local_subtensor_merge(node): ...@@ -1803,7 +1803,9 @@ def local_subtensor_merge(node):
merged_slices = [] merged_slices = []
pos_2 = 0 pos_2 = 0
for pos_1, slice1 in enumerate(slices1): pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1]
if type(slice1) is slice: if type(slice1) is slice:
merged_slices.append( merged_slices.append(
merge_two_slices(slice1, merge_two_slices(slice1,
...@@ -1813,8 +1815,14 @@ def local_subtensor_merge(node): ...@@ -1813,8 +1815,14 @@ def local_subtensor_merge(node):
pos_2 += 1 pos_2 += 1
else: else:
merged_slices.append(slice1) merged_slices.append(slice1)
pos_1 += 1
if pos_2 < len(slices2):
merged_slices += slices2[pos_2:]
else:
merged_slices += slices1[pos_1:]
merged_slices += slices2[pos_2:]
subtens = T.Subtensor(merged_slices) subtens = T.Subtensor(merged_slices)
sl_ins = T.Subtensor.collapse( sl_ins = T.Subtensor.collapse(
merged_slices, merged_slices,
......
...@@ -1822,6 +1822,35 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1822,6 +1822,35 @@ class test_local_subtensor_merge(unittest.TestCase):
val = fun(data) val = fun(data)
assert val == data[7:1:-1][0] assert val == data[7:1:-1][0]
def test_const5(self):
# Bug reported by Graham
data = self.rng.uniform(size=(8,8,8)).astype(theano.config.floatX)
x = theano.tensor.tensor3('x')
# test 1)
y = x[3:6,2:6,1:7][1]
fun = theano.function([x], y)
val = fun(data)
assert numpy.all(val == data[3:6,2:6,1:7][1])
assert len([n for n in fun.maker.env.toposort()
if isinstance(n.op, theano.tensor.basic.Subtensor)]) == 1
# test 2)
y = x[2,3][1]
fun = theano.function([x], y)
val = fun(data)
assert numpy.all(val == data[2,3][1])
assert len([n for n in fun.maker.env.toposort()
if isinstance(n.op, theano.tensor.basic.Subtensor)]) == 1
# test 3)
y = x[3:6,2,1:7][1]
fun = theano.function([x], y)
val = fun(data)
assert numpy.all(val == data[3:6,2,1:7][1])
assert len([n for n in fun.maker.env.toposort()
if isinstance(n.op, theano.tensor.basic.Subtensor)]) == 1
def test_scalar6(self): def test_scalar6(self):
# General case with one slice and one index # General case with one slice and one index
# var[b:e:s][i] # var[b:e:s][i]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论