提交 a65ebd06 authored 作者: AdeB's avatar AdeB

Update check_stack_trace to allow more arguments and types

上级 a536464a
......@@ -7,6 +7,7 @@ from __future__ import absolute_import, print_function, division
from collections import deque
import copy
import inspect
import logging
import pdb
import sys
......@@ -2623,16 +2624,20 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
Parameters
----------
f_or_fgraph: theano.compile.function_module.Function or
theano.gof.fg.FunctionGraph
theano.gof.fg.FunctionGraph
The compiled function or the function graph to be analysed.
ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string or a
function returning a boolean and taking as input a theano.gof.Op.
ops_to_check: it can be of four different types:
- classes or instances inheriting from theano.gof.Op
- tuple/list of classes or instances inheriting from theano.gof.Op
- string
- function returning a boolean and taking as input an instance of
theano.gof.Op.
- if ops_to_check is a string, it should be either 'last' or 'all'.
'last' will check only the last op of the graph while 'all' will
check all the ops of the graph.
- if ops_to_check is an op or a tuple of ops, the function will check
that all the outputs of their occurrences in the graph have a stack
trace.
- if ops_to_check is an op or a tuple/list of ops, the function will
check that all the outputs of their occurrences in the graph have a
stack trace.
- if ops_to_check is a function, it should take as input a
theano.gof.Op and return a boolean indicating if the input op should
be checked or not.
......@@ -2654,6 +2659,12 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
else:
raise ValueError('The type of f_or_fgraph is not supported')
if (isinstance(ops_to_check, theano.gof.Op) or
(inspect.isclass(ops_to_check) and
issubclass(ops_to_check, theano.gof.Op))):
ops_to_check = (ops_to_check,)
# if ops_to_check is a string
if isinstance(ops_to_check, string_types):
if ops_to_check == 'last':
apply_nodes_to_check = [fgraph.outputs[0].owner]
......@@ -2662,27 +2673,39 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
else:
raise ValueError('The string ops_to_check is not recognised')
elif hasattr(ops_to_check, '__call__'): # if ops_to_check is a function
# if ops_to_check is a list/tuple of ops
elif isinstance(ops_to_check, (tuple, list)):
ops_to_check = tuple(ops_to_check)
apply_nodes_to_check = [node for node in fgraph.apply_nodes
if ops_to_check(node)]
if node.op in ops_to_check or
isinstance(node.op, ops_to_check) or
(hasattr(node.op, 'scalar_op') and
isinstance(node.op.scalar_op, ops_to_check))]
else: # if ops_to_check is an op or a list of ops
# if ops_to_check is a function
elif hasattr(ops_to_check, '__call__'):
apply_nodes_to_check = [node for node in fgraph.apply_nodes
if isinstance(node.op, ops_to_check)]
if not apply_nodes_to_check:
msg = 'Provided ops are not in the graph'
if bug_print == 'warn':
warnings.warn(msg)
elif bug_print == 'raise':
raise Exception(msg)
elif bug_print == 'ignore':
pass
else:
raise ValueError('The string bug_print is not recognised')
if ops_to_check(node)]
else:
raise ValueError('ops_to_check does not have the right type')
if not apply_nodes_to_check:
msg = 'Provided op instances/classes are not in the graph or the ' \
'graph is empty'
if bug_print == 'warn':
warnings.warn(msg)
elif bug_print == 'raise':
raise Exception(msg)
elif bug_print == 'ignore':
pass
else:
raise ValueError('The string bug_print is not recognised')
for node in apply_nodes_to_check:
for output in node.outputs:
if not hasattr(output.tag, 'trace'):
if (not hasattr(output.tag, 'trace') or
not output.tag.trace):
return False
return True
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论