提交 8bedd35e authored 作者: AdeB's avatar AdeB

Add helper function check_stack_trace to check if a stack trace is in a graph or not

上级 8c7511bf
......@@ -2613,3 +2613,41 @@ def copy_stack_trace(from_var, to_var):
# Copy over stack traces from from_var to each variable to
# to_var, including the stack_trace of the to_var before
to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr
def check_stack_trace(compiled_f, ops_to_check='last'):
"""
This function checks if the outputs of specific ops of a compiled graph
have a stack.
Parameters
----------
compiled_f: theano.compile.function_module.Function
The compiled function to be analysed.
ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string.
- if ops_to_check is an op or a tuple of ops, the function will check
that all the outputs of their occurrences in the graph have a stack
trace.
- if ops_to_check is a string, it should be either 'last' or 'all'.
'last' will check only the last op of the graph while 'all' will
check all the ops of the graph.
"""
graph = compiled_f.maker.fgraph
if isinstance(ops_to_check, basestring):
if ops_to_check == 'last':
apply_nodes_to_check = [compiled_f.maker.fgraph.outputs[0].owner]
elif ops_to_check == 'all':
apply_nodes_to_check = graph.apply_nodes
else:
raise ValueError('The string ops_to_check is not recognised')
else:
apply_nodes_to_check = [node for node in graph.apply_nodes
if isinstance(node.op, ops_to_check)]
if not apply_nodes_to_check:
warnings.warn('Provided ops are not in the graph')
for node in apply_nodes_to_check:
for output in node.outputs:
assert hasattr(output.tag, 'trace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论