提交 2f6c0084 authored 作者: AdeB's avatar AdeB

Add new arguments/options to the argument of the function

上级 812c1513
...@@ -2615,7 +2615,7 @@ def copy_stack_trace(from_var, to_var): ...@@ -2615,7 +2615,7 @@ def copy_stack_trace(from_var, to_var):
to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr
def check_stack_trace(f_or_fgraph, ops_to_check='last'): def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
""" """
This function checks if the outputs of specific ops of a compiled graph This function checks if the outputs of specific ops of a compiled graph
have a stack. have a stack.
...@@ -2625,20 +2625,34 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last'): ...@@ -2625,20 +2625,34 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last'):
f_or_fgraph: theano.compile.function_module.Function or f_or_fgraph: theano.compile.function_module.Function or
theano.gof.fg.FunctionGraph theano.gof.fg.FunctionGraph
The compiled function or the function graph to be analysed. The compiled function or the function graph to be analysed.
ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string. ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string or a
- if ops_to_check is an op or a tuple of ops, the function will check function returning a boolean and taking as input a theano.gof.Op.
that all the outputs of their occurrences in the graph have a stack
trace.
- if ops_to_check is a string, it should be either 'last' or 'all'. - if ops_to_check is a string, it should be either 'last' or 'all'.
'last' will check only the last op of the graph while 'all' will 'last' will check only the last op of the graph while 'all' will
check all the ops of the graph. check all the ops of the graph.
- if ops_to_check is an op or a tuple of ops, the function will check
that all the outputs of their occurrences in the graph have a stack
trace.
- if ops_to_check is a function, it should take as input a
theano.gof.Op and return a boolean indicating if the input op should
be checked or not.
bug_print: string belonging to {'raise', 'warn', 'ignore'}
You can specify the behaviour of the function when the specified
ops_to_check are not in the graph of f_or_fgraph: it can either raise
an exception, write a warning or simply ignore it.
Returns
-------
boolean
True if the outputs of the specified ops have a stack, False otherwise.
""" """
if isinstance(f_or_fgraph, theano.compile.function_module.Function): if isinstance(f_or_fgraph, theano.compile.function_module.Function):
fgraph = f_or_fgraph.maker.fgraph fgraph = f_or_fgraph.maker.fgraph
elif isinstance(f_or_fgraph, theano.gof.fg.FunctionGraph): elif isinstance(f_or_fgraph, theano.gof.fg.FunctionGraph):
fgraph = f_or_fgraph fgraph = f_or_fgraph
else: else:
raise ValueError('The type of f_f_or_fgraph is not supported') raise ValueError('The type of f_or_fgraph is not supported')
if isinstance(ops_to_check, string_types): if isinstance(ops_to_check, string_types):
if ops_to_check == 'last': if ops_to_check == 'last':
...@@ -2648,12 +2662,27 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last'): ...@@ -2648,12 +2662,27 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last'):
else: else:
raise ValueError('The string ops_to_check is not recognised') raise ValueError('The string ops_to_check is not recognised')
else: elif hasattr(ops_to_check, '__call__'): # if ops_to_check is a function
apply_nodes_to_check = [node for node in fgraph.apply_nodes
if ops_to_check(node)]
else: # if ops_to_check is an op or a list of ops
apply_nodes_to_check = [node for node in fgraph.apply_nodes apply_nodes_to_check = [node for node in fgraph.apply_nodes
if isinstance(node.op, ops_to_check)] if isinstance(node.op, ops_to_check)]
if not apply_nodes_to_check: if not apply_nodes_to_check:
warnings.warn('Provided ops are not in the graph') msg = 'Provided ops are not in the graph'
if bug_print == 'warn':
warnings.warn(msg)
elif bug_print == 'raise':
raise Exception(msg)
elif bug_print == 'ignore':
pass
else:
raise ValueError('The string bug_print is not recognised')
for node in apply_nodes_to_check: for node in apply_nodes_to_check:
for output in node.outputs: for output in node.outputs:
assert hasattr(output.tag, 'trace') if not hasattr(output.tag, 'trace'):
return False
return True
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论