提交 67017a66 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in `local_reshape_to_dimshuffle`

上级 b12dc30a
...@@ -966,16 +966,15 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -966,16 +966,15 @@ def local_reshape_to_dimshuffle(fgraph, node):
inp, output_shape = node.inputs inp, output_shape = node.inputs
[output] = node.outputs [output] = node.outputs
# Remove any broadcastable dimensions from the input
squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
# Trivial case, all dimensions of input/output are known to be broadcastable: # Trivial case, all dimensions of input/output are known to be broadcastable:
# there's nothing to reshape # there's nothing to reshape
if all(inp.type.broadcastable) or all(output.type.broadcastable): if all(inp.type.broadcastable) or all(output.type.broadcastable):
squeeze_axes = tuple(range(inp.type.ndim))
new_output_shape = [] new_output_shape = []
expand_axes = tuple(range(output.type.ndim)) expand_axes = tuple(range(output.type.ndim))
else: else:
squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
unpacked_shape = _unpack_shape_vector(output_shape) unpacked_shape = _unpack_shape_vector(output_shape)
new_output_shape = [] new_output_shape = []
expand_axes = [] expand_axes = []
......
...@@ -445,6 +445,15 @@ class TestLocalReshapeToDimshuffle: ...@@ -445,6 +445,15 @@ class TestLocalReshapeToDimshuffle:
new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt")) new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt"))
assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False) assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False)
def test_reshape_implies_size_1_input(self):
x = pt.matrix("x", shape=(None, None))
out = pt.reshape(x, (1, 1, 1))
new_out = rewrite_graph(out, include=("canonicalize",))
assert equal_computations(
[new_out], [x.dimshuffle("x", "x", "x")], strict_dtype=False
)
def test_expand_dims_squeeze_reshape_fusion(): def test_expand_dims_squeeze_reshape_fusion():
x = pt.tensor("x", shape=(1, 9)) x = pt.tensor("x", shape=(1, 9))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论