提交 cd15b168 authored 作者: AdeB's avatar AdeB

separate classes from instances in check_trace

上级 a65ebd06
...@@ -2675,12 +2675,22 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'): ...@@ -2675,12 +2675,22 @@ def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
# if ops_to_check is a list/tuple of ops # if ops_to_check is a list/tuple of ops
elif isinstance(ops_to_check, (tuple, list)): elif isinstance(ops_to_check, (tuple, list)):
ops_to_check = tuple(ops_to_check) # Separate classes from instances in ops_to_check
apply_nodes_to_check = [node for node in fgraph.apply_nodes op_instances = []
if node.op in ops_to_check or op_classes = []
isinstance(node.op, ops_to_check) or for obj in ops_to_check:
(hasattr(node.op, 'scalar_op') and if isinstance(obj, theano.gof.Op):
isinstance(node.op.scalar_op, ops_to_check))] op_instances.append(obj)
else:
op_classes.append(obj)
op_classes = tuple(op_classes)
apply_nodes_to_check = (
[node for node in fgraph.apply_nodes if node.op in ops_to_check] +
[node for node in fgraph.apply_nodes
if isinstance(node.op, op_classes) or
(hasattr(node.op, 'scalar_op') and
isinstance(node.op.scalar_op, op_classes))])
# if ops_to_check is a function # if ops_to_check is a function
elif hasattr(ops_to_check, '__call__'): elif hasattr(ops_to_check, '__call__'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论