提交 812c1513 authored 作者: AdeB's avatar AdeB

Allows the first argument of check_stack_trace to be a fgraph

上级 8bedd35e
...@@ -2615,15 +2615,16 @@ def copy_stack_trace(from_var, to_var): ...@@ -2615,15 +2615,16 @@ def copy_stack_trace(from_var, to_var):
to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr
def check_stack_trace(compiled_f, ops_to_check='last'): def check_stack_trace(f_or_fgraph, ops_to_check='last'):
""" """
This function checks if the outputs of specific ops of a compiled graph This function checks if the outputs of specific ops of a compiled graph
have a stack. have a stack.
Parameters Parameters
---------- ----------
compiled_f: theano.compile.function_module.Function f_or_fgraph: theano.compile.function_module.Function or
The compiled function to be analysed. theano.gof.fg.FunctionGraph
The compiled function or the function graph to be analysed.
ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string. ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string.
- if ops_to_check is an op or a tuple of ops, the function will check - if ops_to_check is an op or a tuple of ops, the function will check
that all the outputs of their occurrences in the graph have a stack that all the outputs of their occurrences in the graph have a stack
...@@ -2632,18 +2633,23 @@ def check_stack_trace(compiled_f, ops_to_check='last'): ...@@ -2632,18 +2633,23 @@ def check_stack_trace(compiled_f, ops_to_check='last'):
'last' will check only the last op of the graph while 'all' will 'last' will check only the last op of the graph while 'all' will
check all the ops of the graph. check all the ops of the graph.
""" """
graph = compiled_f.maker.fgraph if isinstance(f_or_fgraph, theano.compile.function_module.Function):
fgraph = f_or_fgraph.maker.fgraph
elif isinstance(f_or_fgraph, theano.gof.fg.FunctionGraph):
fgraph = f_or_fgraph
else:
raise ValueError('The type of f_f_or_fgraph is not supported')
if isinstance(ops_to_check, basestring): if isinstance(ops_to_check, string_types):
if ops_to_check == 'last': if ops_to_check == 'last':
apply_nodes_to_check = [compiled_f.maker.fgraph.outputs[0].owner] apply_nodes_to_check = [fgraph.outputs[0].owner]
elif ops_to_check == 'all': elif ops_to_check == 'all':
apply_nodes_to_check = graph.apply_nodes apply_nodes_to_check = fgraph.apply_nodes
else: else:
raise ValueError('The string ops_to_check is not recognised') raise ValueError('The string ops_to_check is not recognised')
else: else:
apply_nodes_to_check = [node for node in graph.apply_nodes apply_nodes_to_check = [node for node in fgraph.apply_nodes
if isinstance(node.op, ops_to_check)] if isinstance(node.op, ops_to_check)]
if not apply_nodes_to_check: if not apply_nodes_to_check:
warnings.warn('Provided ops are not in the graph') warnings.warn('Provided ops are not in the graph')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论