提交 fe10a775 authored 作者: Samira Ebrahimi Kahou's avatar Samira Ebrahimi Kahou

changed check_stack_trace to first check for ops or tuple of ops before __call__…

changed check_stack_trace to first check for ops or tuple of ops before __call__ attribute, because classes have a call attribute (constructor).
上级 71d08550
...@@ -2667,7 +2667,8 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'): ...@@ -2667,7 +2667,8 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
# if ops_to_check is a string # if ops_to_check is a string
if isinstance(ops_to_check, string_types): if isinstance(ops_to_check, string_types):
if ops_to_check == 'last': if ops_to_check == 'last':
apply_nodes_to_check = [fgraph.outputs[i].owner for i in range(len(fgraph.outputs))] apply_nodes_to_check = [fgraph.outputs[i].owner for i in range(
len(fgraph.outputs))]
elif ops_to_check == 'all': elif ops_to_check == 'all':
apply_nodes_to_check = fgraph.apply_nodes apply_nodes_to_check = fgraph.apply_nodes
else: else:
...@@ -2712,6 +2713,13 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'): ...@@ -2712,6 +2713,13 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
else: else:
raise ValueError('The string bug_print is not recognised') raise ValueError('The string bug_print is not recognised')
elif hasattr(ops_to_check, '__call__'): # if ops_to_check is a function
apply_nodes_to_check = [node for node in fgraph.apply_nodes
if ops_to_check(node)]
else: # if ops_to_check is an op or a list of ops
raise ValueError('The value of ops_to_check is not supported')
for node in apply_nodes_to_check: for node in apply_nodes_to_check:
for output in node.outputs: for output in node.outputs:
if (not hasattr(output.tag, 'trace') or if (not hasattr(output.tag, 'trace') or
......
...@@ -2125,7 +2125,7 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -2125,7 +2125,7 @@ class test_local_subtensor_merge(unittest.TestCase):
'local_subtensor_merge')) 'local_subtensor_merge'))
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(check_stack_trace(f, ops_to_check='last')) self.assertTrue(check_stack_trace(f, ops_to_check=Subtensor))
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len([t for t in topo assert len([t for t in topo
...@@ -5929,9 +5929,7 @@ def test_local_useless_split(): ...@@ -5929,9 +5929,7 @@ def test_local_useless_split():
# Check that stacktraces have been copied over properly # Check that stacktraces have been copied over properly
assert check_stack_trace(f_opt, ops_to_check='all') assert check_stack_trace(f_opt, ops_to_check='all')
assert len(f_opt.outputs[0].variable.tag.trace) > 0
assert check_stack_trace(f_nonopt, ops_to_check='all') assert check_stack_trace(f_nonopt, ops_to_check='all')
assert len(f_nonopt.outputs[0].variable.tag.trace) > 0
def test_local_flatten_lift(): def test_local_flatten_lift():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论