提交 7f3f740f authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #4407 from adbrebs/update_check_trace

Update check trace to allow more argument types
...@@ -7,6 +7,7 @@ from __future__ import absolute_import, print_function, division ...@@ -7,6 +7,7 @@ from __future__ import absolute_import, print_function, division
from collections import deque from collections import deque
import copy import copy
import inspect
import logging import logging
import pdb import pdb
import sys import sys
...@@ -2623,16 +2624,20 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'): ...@@ -2623,16 +2624,20 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
Parameters Parameters
---------- ----------
f_or_fgraph: theano.compile.function_module.Function or f_or_fgraph: theano.compile.function_module.Function or
theano.gof.fg.FunctionGraph theano.gof.fg.FunctionGraph
The compiled function or the function graph to be analysed. The compiled function or the function graph to be analysed.
ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string or a ops_to_check: it can be of four different types:
function returning a boolean and taking as input a theano.gof.Op. - classes or instances inheriting from theano.gof.Op
- tuple/list of classes or instances inheriting from theano.gof.Op
- string
- function returning a boolean and taking as input an instance of
theano.gof.Op.
- if ops_to_check is a string, it should be either 'last' or 'all'. - if ops_to_check is a string, it should be either 'last' or 'all'.
'last' will check only the last op of the graph while 'all' will 'last' will check only the last op of the graph while 'all' will
check all the ops of the graph. check all the ops of the graph.
- if ops_to_check is an op or a tuple of ops, the function will check - if ops_to_check is an op or a tuple/list of ops, the function will
that all the outputs of their occurrences in the graph have a stack check that all the outputs of their occurrences in the graph have a
trace. stack trace.
- if ops_to_check is a function, it should take as input a - if ops_to_check is a function, it should take as input a
theano.gof.Op and return a boolean indicating if the input op should theano.gof.Op and return a boolean indicating if the input op should
be checked or not. be checked or not.
...@@ -2654,6 +2659,12 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'): ...@@ -2654,6 +2659,12 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
else: else:
raise ValueError('The type of f_or_fgraph is not supported') raise ValueError('The type of f_or_fgraph is not supported')
if (isinstance(ops_to_check, theano.gof.Op) or
(inspect.isclass(ops_to_check) and
issubclass(ops_to_check, theano.gof.Op))):
ops_to_check = (ops_to_check,)
# if ops_to_check is a string
if isinstance(ops_to_check, string_types): if isinstance(ops_to_check, string_types):
if ops_to_check == 'last': if ops_to_check == 'last':
apply_nodes_to_check = [fgraph.outputs[0].owner] apply_nodes_to_check = [fgraph.outputs[0].owner]
...@@ -2662,27 +2673,49 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'): ...@@ -2662,27 +2673,49 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
else: else:
raise ValueError('The string ops_to_check is not recognised') raise ValueError('The string ops_to_check is not recognised')
elif hasattr(ops_to_check, '__call__'): # if ops_to_check is a function # if ops_to_check is a list/tuple of ops
elif isinstance(ops_to_check, (tuple, list)):
# Separate classes from instances in ops_to_check
op_instances = []
op_classes = []
for obj in ops_to_check:
if isinstance(obj, theano.gof.Op):
op_instances.append(obj)
else:
op_classes.append(obj)
op_classes = tuple(op_classes)
apply_nodes_to_check = (
[node for node in fgraph.apply_nodes if node.op in ops_to_check] +
[node for node in fgraph.apply_nodes
if isinstance(node.op, op_classes) or
(hasattr(node.op, 'scalar_op') and
isinstance(node.op.scalar_op, op_classes))])
# if ops_to_check is a function
elif hasattr(ops_to_check, '__call__'):
apply_nodes_to_check = [node for node in fgraph.apply_nodes apply_nodes_to_check = [node for node in fgraph.apply_nodes
if ops_to_check(node)] if ops_to_check(node)]
else: # if ops_to_check is an op or a list of ops else:
apply_nodes_to_check = [node for node in fgraph.apply_nodes raise ValueError('ops_to_check does not have the right type')
if isinstance(node.op, ops_to_check)]
if not apply_nodes_to_check: if not apply_nodes_to_check:
msg = 'Provided ops are not in the graph' msg = 'Provided op instances/classes are not in the graph or the ' \
if bug_print == 'warn': 'graph is empty'
warnings.warn(msg) if bug_print == 'warn':
elif bug_print == 'raise': warnings.warn(msg)
raise Exception(msg) elif bug_print == 'raise':
elif bug_print == 'ignore': raise Exception(msg)
pass elif bug_print == 'ignore':
else: pass
raise ValueError('The string bug_print is not recognised') else:
raise ValueError('The string bug_print is not recognised')
for node in apply_nodes_to_check: for node in apply_nodes_to_check:
for output in node.outputs: for output in node.outputs:
if not hasattr(output.tag, 'trace'): if (not hasattr(output.tag, 'trace') or
not output.tag.trace):
return False return False
return True return True
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论