提交 83d458b5 authored 作者: Vincent Michalski's avatar Vincent Michalski

local_useless_split

上级 c69c657b
...@@ -4022,11 +4022,13 @@ def local_useless_split(node): ...@@ -4022,11 +4022,13 @@ def local_useless_split(node):
if node.op.len_splits == 1: if node.op.len_splits == 1:
x, axis, splits = node.inputs x, axis, splits = node.inputs
out = assert_op(x, T.eq(splits.shape[0], 1)) out = assert_op(x, T.eq(splits.shape[0], 1))
out = assert_op(out, T.eq(x.shape[axis], splits[0]))
# Copy over stacktrace from previous output node. # Copy over stacktrace from previous output node.
copy_stack_trace(node.outputs, out) copy_stack_trace(node.outputs, out)
return [out] out2 = assert_op(out, T.eq(x.shape[axis], splits[0]))
# Copy over stacktrace from previous output node.
copy_stack_trace(out, out2)
return [out2]
################ ################
......
...@@ -6062,7 +6062,7 @@ def test_local_useless_split(): ...@@ -6062,7 +6062,7 @@ def test_local_useless_split():
# Check if there are use cases that are not covered here # Check if there are use cases that are not covered here
# and if the line below is necessary and correct (See issue #4421) # and if the line below is necessary and correct (See issue #4421)
# assert check_stack_trace(f_opt, ops_to_check=[Assert]) assert check_stack_trace(f_opt, ops_to_check=[Assert])
assert check_stack_trace(f_nonopt, ops_to_check='all') assert check_stack_trace(f_nonopt, ops_to_check='all')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论