提交 ca9499d5 authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Did not check whether the slice(None, None, None) also corresponded to a legitimately dropped dimension

上级 8c654cf1
...@@ -220,16 +220,19 @@ def local_dimshuffle_subtensor(node): ...@@ -220,16 +220,19 @@ def local_dimshuffle_subtensor(node):
zero = T.constant(0) zero = T.constant(0)
slice_attr_list = ['start','stop','step'] slice_attr_list = ['start','stop','step']
j = 0 j = 0
slice_i = -1
for idx in input_.owner.op.idx_list: for idx in input_.owner.op.idx_list:
if isinstance(idx, slice): if isinstance(idx, slice):
past_j = j past_j = j
slice_i += 1
for slice_attr in slice_attr_list: for slice_attr in slice_attr_list:
if getattr(idx, slice_attr) is not None: if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1+j]] new_inputs += [input_.owner.inputs[1+j]]
j += 1 j += 1
if past_j == j: # if past_j == j indicates a slice(None, None, None), that's where
# here is a slice(None, None, None), that's where # we want to index with 0 if it is also at the same
# we want to index with 0. # spot of a missing dim
if past_j == j and slice_i in missing_dims:
new_idx_list[j] = zero new_idx_list[j] = zero
new_inputs += [zero] new_inputs += [zero]
else: else:
......
...@@ -178,7 +178,7 @@ def test_local_dimshuffle_subtensor(): ...@@ -178,7 +178,7 @@ def test_local_dimshuffle_subtensor():
x = tensor.patternbroadcast(x, (False, True, False, False)) x = tensor.patternbroadcast(x, (False, True, False, False))
i = tensor.iscalar('i') i = tensor.iscalar('i')
out = x[i, :, 10:30, ::-1].dimshuffle(1,2) out = x[:, :, 10:30, ::i].dimshuffle(0,2,3)
g = FunctionGraph([x,i], [out]) g = FunctionGraph([x,i], [out])
dimshuffle_subtensor(g) dimshuffle_subtensor(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论