提交 ca9499d5 authored 作者: Olivier Mastropietro's avatar Olivier Mastropietro

Did not check if the slice(None, None, None) was also part of a legit dropped dimensions

上级 8c654cf1
......@@ -220,16 +220,19 @@ def local_dimshuffle_subtensor(node):
zero = T.constant(0)
slice_attr_list = ['start','stop','step']
j = 0
slice_i = -1
for idx in input_.owner.op.idx_list:
if isinstance(idx, slice):
past_j = j
slice_i += 1
for slice_attr in slice_attr_list:
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1+j]]
j += 1
if past_j == j:
# here is a slice(None, None, None), that's where
# we want to index with 0.
# if past_j == j indicates a slice(None, None, None), that's where
# we want to index with 0 if it is also at the same
# spot of a missing dim
if past_j == j and slice_i in missing_dims:
new_idx_list[j] = zero
new_inputs += [zero]
else:
......
......@@ -178,7 +178,7 @@ def test_local_dimshuffle_subtensor():
x = tensor.patternbroadcast(x, (False, True, False, False))
i = tensor.iscalar('i')
out = x[i, :, 10:30, ::-1].dimshuffle(1,2)
out = x[:, :, 10:30, ::i].dimshuffle(0,2,3)
g = FunctionGraph([x,i], [out])
dimshuffle_subtensor(g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论