提交 2f6c0084 authored 作者: AdeB's avatar AdeB

Add new arguments/options to the argument of the function

上级 812c1513
......@@ -2615,7 +2615,7 @@ def copy_stack_trace(from_var, to_var):
to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr
def check_stack_trace(f_or_fgraph, ops_to_check='last'):
def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
"""
This function checks if the outputs of specific ops of a compiled graph
have a stack.
......@@ -2625,20 +2625,34 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last'):
f_or_fgraph: theano.compile.function_module.Function or
theano.gof.fg.FunctionGraph
The compiled function or the function graph to be analysed.
ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string.
- if ops_to_check is an op or a tuple of ops, the function will check
that all the outputs of their occurrences in the graph have a stack
trace.
ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string or a
function returning a boolean and taking as input a theano.gof.Op.
- if ops_to_check is a string, it should be either 'last' or 'all'.
'last' will check only the last op of the graph while 'all' will
check all the ops of the graph.
- if ops_to_check is an op or a tuple of ops, the function will check
that all the outputs of their occurrences in the graph have a stack
trace.
- if ops_to_check is a function, it should take as input a
theano.gof.Op and return a boolean indicating if the input op should
be checked or not.
bug_print: string belonging to {'raise', 'warn', 'ignore'}
You can specify the behaviour of the function when the specified
ops_to_check are not in the graph of f_or_fgraph: it can either raise
an exception, write a warning or simply ignore it.
Returns
-------
boolean
True if the outputs of the specified ops have a stack, False otherwise.
"""
if isinstance(f_or_fgraph, theano.compile.function_module.Function):
fgraph = f_or_fgraph.maker.fgraph
elif isinstance(f_or_fgraph, theano.gof.fg.FunctionGraph):
fgraph = f_or_fgraph
else:
raise ValueError('The type of f_f_or_fgraph is not supported')
raise ValueError('The type of f_or_fgraph is not supported')
if isinstance(ops_to_check, string_types):
if ops_to_check == 'last':
......@@ -2648,12 +2662,27 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last'):
else:
raise ValueError('The string ops_to_check is not recognised')
else:
elif hasattr(ops_to_check, '__call__'): # if ops_to_check is a function
apply_nodes_to_check = [node for node in fgraph.apply_nodes
if ops_to_check(node)]
else: # if ops_to_check is an op or a list of ops
apply_nodes_to_check = [node for node in fgraph.apply_nodes
if isinstance(node.op, ops_to_check)]
if not apply_nodes_to_check:
warnings.warn('Provided ops are not in the graph')
msg = 'Provided ops are not in the graph'
if bug_print == 'warn':
warnings.warn(msg)
elif bug_print == 'raise':
raise Exception(msg)
elif bug_print == 'ignore':
pass
else:
raise ValueError('The string bug_print is not recognised')
for node in apply_nodes_to_check:
for output in node.outputs:
assert hasattr(output.tag, 'trace')
if not hasattr(output.tag, 'trace'):
return False
return True
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论