提交 ee47964c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #4290 from adbrebs/stack_trace_opt

[WIP] Helper function that check stack traces
......@@ -2568,3 +2568,121 @@ def pre_greedy_local_optimizer(list_optimizations, out):
final_outs, optimized_nodes = local_recursive_function(
list_optimizations, out, {}, 0)
return final_outs[out_index]
def copy_stack_trace(from_var, to_var):
"""
Copies the stack trace from one or more tensor variables to
one or more tensor variables.
Parameters
----------
from_var
Tensor variable or list of tensor variables to copy stack traces from.
to_var
Tensor variable or list of tensor variables to copy stack traces to.
Notes
-----
The stacktrace is assumed to be of the form of a list of lists
of tuples. Each tuple contains the filename, line number, function name
and so on. Each list of tuples contains the truples belonging to a
particular variable.
"""
# Store stack traces from from_var
tr = []
if type(from_var) is list:
# If from_var is a list, store concatenated stack traces
for v in from_var:
tr += getattr(v.tag, 'trace', [])
else:
# If from_var is not a list, it must be a single tensor variable,
# so just store that particular stack trace
tr = getattr(from_var.tag, 'trace', [])
# Copy over stack traces to to_var
if type(to_var) is list:
# Copy over stack traces from from_var to each variable in
# to_var, including the stack_trace of the to_var before
for v in to_var:
v.tag.trace = getattr(v.tag, 'trace', []) + tr
else:
# Copy over stack traces from from_var to each variable to
# to_var, including the stack_trace of the to_var before
to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr
def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
"""
This function checks if the outputs of specific ops of a compiled graph
have a stack.
Parameters
----------
f_or_fgraph: theano.compile.function_module.Function or
theano.gof.fg.FunctionGraph
The compiled function or the function graph to be analysed.
ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string or a
function returning a boolean and taking as input a theano.gof.Op.
- if ops_to_check is a string, it should be either 'last' or 'all'.
'last' will check only the last op of the graph while 'all' will
check all the ops of the graph.
- if ops_to_check is an op or a tuple of ops, the function will check
that all the outputs of their occurrences in the graph have a stack
trace.
- if ops_to_check is a function, it should take as input a
theano.gof.Op and return a boolean indicating if the input op should
be checked or not.
bug_print: string belonging to {'raise', 'warn', 'ignore'}
You can specify the behaviour of the function when the specified
ops_to_check are not in the graph of f_or_fgraph: it can either raise
an exception, write a warning or simply ignore it.
Returns
-------
boolean
True if the outputs of the specified ops have a stack, False otherwise.
"""
if isinstance(f_or_fgraph, theano.compile.function_module.Function):
fgraph = f_or_fgraph.maker.fgraph
elif isinstance(f_or_fgraph, theano.gof.fg.FunctionGraph):
fgraph = f_or_fgraph
else:
raise ValueError('The type of f_or_fgraph is not supported')
if isinstance(ops_to_check, string_types):
if ops_to_check == 'last':
apply_nodes_to_check = [fgraph.outputs[0].owner]
elif ops_to_check == 'all':
apply_nodes_to_check = fgraph.apply_nodes
else:
raise ValueError('The string ops_to_check is not recognised')
elif hasattr(ops_to_check, '__call__'): # if ops_to_check is a function
apply_nodes_to_check = [node for node in fgraph.apply_nodes
if ops_to_check(node)]
else: # if ops_to_check is an op or a list of ops
apply_nodes_to_check = [node for node in fgraph.apply_nodes
if isinstance(node.op, ops_to_check)]
if not apply_nodes_to_check:
msg = 'Provided ops are not in the graph'
if bug_print == 'warn':
warnings.warn(msg)
elif bug_print == 'raise':
raise Exception(msg)
elif bug_print == 'ignore':
pass
else:
raise ValueError('The string bug_print is not recognised')
for node in apply_nodes_to_check:
for output in node.outputs:
if not hasattr(output.tag, 'trace'):
return False
return True
......@@ -2,9 +2,9 @@ from __future__ import absolute_import, print_function, division
import theano
from theano.gradient import DisconnectedType
from theano.gof import Op, Apply, TopoOptimizer
from theano.gof.opt import copy_stack_trace
from theano import tensor
import theano.sandbox.cuda as cuda
from theano.tensor.opt import copy_stack_trace
def get_diagonal_subtensor_view(x, i0, i1):
......
......@@ -20,10 +20,10 @@ from six.moves import xrange
import theano
from theano import gof
from theano import scalar
from theano.gof.opt import copy_stack_trace
from theano.tensor import basic as tensor, subtensor, opt, elemwise
from theano.tensor.type import (values_eq_approx_remove_inf,
values_eq_approx_remove_nan)
from theano.tensor.opt import copy_stack_trace
from theano.compile import optdb
from theano.gof import Apply
......
......@@ -6,6 +6,7 @@ import theano
from theano import compile, gof
from theano.compile import optdb
from theano.gof import local_optimizer
from theano.gof.opt import copy_stack_trace
from theano.tensor.nnet.corr import (
CorrMM, CorrMM_gradInputs, CorrMM_gradWeights)
......@@ -18,8 +19,7 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs)
from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.opt import (copy_stack_trace,
register_specialize_device)
from theano.tensor.opt import register_specialize_device
from theano.tensor import TensorType
from theano.tensor import opt
......
......@@ -18,7 +18,7 @@ from theano.printing import pprint
from theano.tensor import basic as tensor
from theano.tensor import elemwise, opt, NotScalarConstantError
from theano.tensor.type import values_eq_approx_remove_inf
from theano.tensor.opt import copy_stack_trace
from theano.gof.opt import copy_stack_trace
############
#
......
......@@ -22,6 +22,7 @@ from theano import gof
from theano.compat import izip
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant
from theano.gof.opt import copy_stack_trace
from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType
from theano.configparser import config
......@@ -54,51 +55,6 @@ _logger = logging.getLogger('theano.tensor.opt')
# Utilities
def copy_stack_trace(from_var, to_var):
"""
Copies the stack trace from one or more tensor variables to
one or more tensor variables.
Parameters
----------
from_var
Tensor variable or list of tensor variables to copy stack traces from.
to_var
Tensor variable or list of tensor variables to copy stack traces to.
Notes
-----
The stacktrace is assumed to be of the form of a list of lists
of tuples. Each tuple contains the filename, line number, function name
and so on. Each list of tuples contains the truples belonging to a
particular variable.
"""
# Store stack traces from from_var
tr = []
if type(from_var) is list:
# If from_var is a list, store concatenated stack traces
for v in from_var:
tr += getattr(v.tag, 'trace', [])
else:
# If from_var is not a list, it must be a single tensor variable,
# so just store that particular stack trace
tr = getattr(from_var.tag, 'trace', [])
# Copy over stack traces to to_var
if type(to_var) is list:
# Copy over stack traces from from_var to each variable in
# to_var, including the stack_trace of the to_var before
for v in to_var:
v.tag.trace = getattr(v.tag, 'trace', []) + tr
else:
# Copy over stack traces from from_var to each variable to
# to_var, including the stack_trace of the to_var before
to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr
def out2in(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论