提交 a8f91938 authored 作者: Frederic Bastien's avatar Frederic Bastien

Bug/crash fix. crash fix reported, but it is not sure it would always return a crash.

上级 091f1a29
...@@ -217,7 +217,7 @@ def local_dimshuffle_subtensor(node): ...@@ -217,7 +217,7 @@ def local_dimshuffle_subtensor(node):
j = 0 j = 0
slice_i = -1 slice_i = -1
subtensor_removed_dims = 0 subtensor_removed_dims = 0
for idx in input_.owner.op.idx_list: for i, idx in enumerate(input_.owner.op.idx_list):
if isinstance(idx, slice): if isinstance(idx, slice):
past_j = j past_j = j
slice_i += 1 slice_i += 1
...@@ -229,7 +229,7 @@ def local_dimshuffle_subtensor(node): ...@@ -229,7 +229,7 @@ def local_dimshuffle_subtensor(node):
# that's where we want to index with 0 if it is also at # that's where we want to index with 0 if it is also at
# the same spot of a missing dim # the same spot of a missing dim
if past_j == j and slice_i in missing_dims: if past_j == j and slice_i in missing_dims:
new_idx_list[j] = zero new_idx_list[i] = zero
new_inputs += [zero] new_inputs += [zero]
else: else:
new_inputs += [input_.owner.inputs[1 + j]] new_inputs += [input_.owner.inputs[1 + j]]
......
...@@ -207,3 +207,9 @@ def test_local_dimshuffle_subtensor(): ...@@ -207,3 +207,9 @@ def test_local_dimshuffle_subtensor():
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert any([not isinstance(x, DimShuffle) for x in topo]) assert any([not isinstance(x, DimShuffle) for x in topo])
assert f(np.random.rand(5, 1, 4, 1), 2).shape == (4,) assert f(np.random.rand(5, 1, 4, 1), 2).shape == (4,)
# Test a corner case that had Theano return a bug.
x = tensor.dtensor4('x')
x = tensor.patternbroadcast(x, (False, True, False, False))
assert x[:,:, 0:3, ::-1].dimshuffle(0,2,3).eval({x: np.ones((5, 1, 6, 7))}).shape == (5, 3, 7)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论