提交 8dc67eae authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in local_dimshuffle_subtensor rewrite

上级 4efbd193
......@@ -174,11 +174,15 @@ def local_dimshuffle_alloc(fgraph, node):
def local_dimshuffle_subtensor(fgraph, node):
"""If a subtensor is inside a dimshuffle which only drop
broadcastable dimensions, scrap the dimshuffle and index the
subtensor with 0
subtensor in a way that avoids the degenerate dimension
x[i:j, :, k:l].dimshuffle(0, 2) =>
x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
x[i:j, k:l, :].dimshuffle(0, 2) => x[i:j, k, :]
x[i:j, k:, :].dimshuffle(0, 2) => x[i:j, k, :]
x[i:j, :l, :].dimshuffle(0, 2) => x[i:j, 0, :]
"""
if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
# the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
......@@ -217,24 +221,40 @@ def local_dimshuffle_subtensor(fgraph, node):
new_idx_list = list(input_.owner.op.idx_list)
new_inputs = [input_.owner.inputs[0]]
zero = constant(0)
slice_attr_list = ["start", "stop", "step"]
j = 0
slice_i = -1
subtensor_removed_dims = 0
for i, idx in enumerate(input_.owner.op.idx_list):
if isinstance(idx, slice):
past_j = j
slice_i += 1
for slice_attr in slice_attr_list:
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
# if past_j == j indicates a slice(None, None, None),
# that's where we want to index with 0 if it is also at
# the same spot of a missing dim
if past_j == j and slice_i in missing_dims:
new_idx_list[i] = zero
new_inputs += [zero]
if slice_i in missing_dims:
# Missing dim is a slice(None), remove by indexing by 0
if idx == slice(None):
new_idx_list[i] = zero
new_inputs += [zero]
# Missing dim is an ordinary slice with known output dim length of 1
# Remove by indexing by start
else:
if idx.start is None:
start = zero
else:
start = input_.owner.inputs[1 + j]
j += 1
new_idx_list[i] = start
new_inputs += [start]
# Ignore useless stop and step input if there is one
for slice_attr in ("stop", "step"):
if getattr(idx, slice_attr) is not None:
j += 1
# Keep non-dropped slice inputs
else:
for slice_attr in ("start", "stop", "step"):
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
# Keep non-dropped non-slice inputs
else:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
......
......@@ -214,3 +214,11 @@ def test_local_dimshuffle_subtensor():
assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval(
{x: np.ones((5, 1, 6, 7))}
).shape == (5, 3, 7)
# Test dropped sliced dimensions
x = matrix("x", shape=(5, 4), dtype="float64")
assert x[2:3, :-2].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (2,)
assert x[:1, 0:3].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (3,)
assert x[-1:, :].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (4,)
assert x[4:3:-1, 1:].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (3,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论