提交 8dc67eae authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in local_dimshuffle_subtensor rewrite

上级 4efbd193
...@@ -174,11 +174,15 @@ def local_dimshuffle_alloc(fgraph, node): ...@@ -174,11 +174,15 @@ def local_dimshuffle_alloc(fgraph, node):
def local_dimshuffle_subtensor(fgraph, node): def local_dimshuffle_subtensor(fgraph, node):
"""If a subtensor is inside a dimshuffle which only drop """If a subtensor is inside a dimshuffle which only drop
broadcastable dimensions, scrap the dimshuffle and index the broadcastable dimensions, scrap the dimshuffle and index the
subtensor with 0 subtensor in a way that avoids the degenerate dimension
x[i:j, :, k:l].dimshuffle(0, 2) => x[i:j, :, k:l].dimshuffle(0, 2) =>
x[i:j, 0, k:l] if x.broadcastable == (False, True, False) x[i:j, 0, k:l] if x.broadcastable == (False, True, False)
x[i:j, k:l, :].dimshuffle(0, 2) => x[i:j, k, :]
x[i:j, k:, :].dimshuffle(0, 2) => x[i:j, k, :]
x[i:j, :l, :].dimshuffle(0, 2) => x[i:j, 0, :]
""" """
if isinstance(node.op, DimShuffle) and node.inputs[0].owner: if isinstance(node.op, DimShuffle) and node.inputs[0].owner:
# the dimshuffle can only drop dimensions (cannot reshape nor add 'x') # the dimshuffle can only drop dimensions (cannot reshape nor add 'x')
...@@ -217,24 +221,40 @@ def local_dimshuffle_subtensor(fgraph, node): ...@@ -217,24 +221,40 @@ def local_dimshuffle_subtensor(fgraph, node):
new_idx_list = list(input_.owner.op.idx_list) new_idx_list = list(input_.owner.op.idx_list)
new_inputs = [input_.owner.inputs[0]] new_inputs = [input_.owner.inputs[0]]
zero = constant(0) zero = constant(0)
slice_attr_list = ["start", "stop", "step"]
j = 0 j = 0
slice_i = -1 slice_i = -1
subtensor_removed_dims = 0 subtensor_removed_dims = 0
for i, idx in enumerate(input_.owner.op.idx_list): for i, idx in enumerate(input_.owner.op.idx_list):
if isinstance(idx, slice): if isinstance(idx, slice):
past_j = j
slice_i += 1 slice_i += 1
for slice_attr in slice_attr_list: if slice_i in missing_dims:
if getattr(idx, slice_attr) is not None: # Missing dim is a slice(None), remove by indexing by 0
new_inputs += [input_.owner.inputs[1 + j]] if idx == slice(None):
j += 1 new_idx_list[i] = zero
# if past_j == j indicates a slice(None, None, None), new_inputs += [zero]
# that's where we want to index with 0 if it is also at # Missing dim is an ordinary slice with known output dim length of 1
# the same spot of a missing dim # Remove by indexing by start
if past_j == j and slice_i in missing_dims: else:
new_idx_list[i] = zero if idx.start is None:
new_inputs += [zero] start = zero
else:
start = input_.owner.inputs[1 + j]
j += 1
new_idx_list[i] = start
new_inputs += [start]
# Ignore useless stop and step input if there is one
for slice_attr in ("stop", "step"):
if getattr(idx, slice_attr) is not None:
j += 1
# Keep non-dropped slice inputs
else:
for slice_attr in ("start", "stop", "step"):
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
# Keep non-dropped non-slice inputs
else: else:
new_inputs += [input_.owner.inputs[1 + j]] new_inputs += [input_.owner.inputs[1 + j]]
j += 1 j += 1
......
...@@ -214,3 +214,11 @@ def test_local_dimshuffle_subtensor(): ...@@ -214,3 +214,11 @@ def test_local_dimshuffle_subtensor():
assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval( assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval(
{x: np.ones((5, 1, 6, 7))} {x: np.ones((5, 1, 6, 7))}
).shape == (5, 3, 7) ).shape == (5, 3, 7)
# Test dropped sliced dimensions
x = matrix("x", shape=(5, 4), dtype="float64")
assert x[2:3, :-2].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (2,)
assert x[:1, 0:3].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (3,)
assert x[-1:, :].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (4,)
assert x[4:3:-1, 1:].dimshuffle(1).eval({x: np.ones(x.type.shape)}).shape == (3,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论