提交 780813c7 authored 作者: lamblin's avatar lamblin

Merge pull request #578 from pascanur/fix_subtensor_problem

making the subtensor merge optimization work
......@@ -1803,7 +1803,9 @@ def local_subtensor_merge(node):
merged_slices = []
pos_2 = 0
for pos_1, slice1 in enumerate(slices1):
pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1]
if type(slice1) is slice:
merged_slices.append(
merge_two_slices(slice1,
......@@ -1813,8 +1815,14 @@ def local_subtensor_merge(node):
pos_2 += 1
else:
merged_slices.append(slice1)
pos_1 += 1
if pos_2 < len(slices2):
merged_slices += slices2[pos_2:]
else:
merged_slices += slices1[pos_1:]
merged_slices += slices2[pos_2:]
subtens = T.Subtensor(merged_slices)
sl_ins = T.Subtensor.collapse(
merged_slices,
......
......@@ -1822,6 +1822,35 @@ class test_local_subtensor_merge(unittest.TestCase):
val = fun(data)
assert val == data[7:1:-1][0]
def test_const5(self):
# Bug reported by Graham
data = self.rng.uniform(size=(8,8,8)).astype(theano.config.floatX)
x = theano.tensor.tensor3('x')
# test 1)
y = x[3:6,2:6,1:7][1]
fun = theano.function([x], y)
val = fun(data)
assert numpy.all(val == data[3:6,2:6,1:7][1])
assert len([n for n in fun.maker.env.toposort()
if isinstance(n.op, theano.tensor.basic.Subtensor)]) == 1
# test 2)
y = x[2,3][1]
fun = theano.function([x], y)
val = fun(data)
assert numpy.all(val == data[2,3][1])
assert len([n for n in fun.maker.env.toposort()
if isinstance(n.op, theano.tensor.basic.Subtensor)]) == 1
# test 3)
y = x[3:6,2,1:7][1]
fun = theano.function([x], y)
val = fun(data)
assert numpy.all(val == data[3:6,2,1:7][1])
assert len([n for n in fun.maker.env.toposort()
if isinstance(n.op, theano.tensor.basic.Subtensor)]) == 1
def test_scalar6(self):
# General case with one slice and one index
# var[b:e:s][i]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论