提交 ee47964c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #4290 from adbrebs/stack_trace_opt

[WIP] Helper function that check stack traces
...@@ -2568,3 +2568,121 @@ def pre_greedy_local_optimizer(list_optimizations, out): ...@@ -2568,3 +2568,121 @@ def pre_greedy_local_optimizer(list_optimizations, out):
final_outs, optimized_nodes = local_recursive_function( final_outs, optimized_nodes = local_recursive_function(
list_optimizations, out, {}, 0) list_optimizations, out, {}, 0)
return final_outs[out_index] return final_outs[out_index]
def copy_stack_trace(from_var, to_var):
"""
Copies the stack trace from one or more tensor variables to
one or more tensor variables.
Parameters
----------
from_var
Tensor variable or list of tensor variables to copy stack traces from.
to_var
Tensor variable or list of tensor variables to copy stack traces to.
Notes
-----
The stacktrace is assumed to be of the form of a list of lists
of tuples. Each tuple contains the filename, line number, function name
and so on. Each list of tuples contains the truples belonging to a
particular variable.
"""
# Store stack traces from from_var
tr = []
if type(from_var) is list:
# If from_var is a list, store concatenated stack traces
for v in from_var:
tr += getattr(v.tag, 'trace', [])
else:
# If from_var is not a list, it must be a single tensor variable,
# so just store that particular stack trace
tr = getattr(from_var.tag, 'trace', [])
# Copy over stack traces to to_var
if type(to_var) is list:
# Copy over stack traces from from_var to each variable in
# to_var, including the stack_trace of the to_var before
for v in to_var:
v.tag.trace = getattr(v.tag, 'trace', []) + tr
else:
# Copy over stack traces from from_var to each variable to
# to_var, including the stack_trace of the to_var before
to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr
def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
"""
This function checks if the outputs of specific ops of a compiled graph
have a stack.
Parameters
----------
f_or_fgraph: theano.compile.function_module.Function or
theano.gof.fg.FunctionGraph
The compiled function or the function graph to be analysed.
ops_to_check: theano.gof.Op or tuple of theano.gof.Op or a string or a
function returning a boolean and taking as input a theano.gof.Op.
- if ops_to_check is a string, it should be either 'last' or 'all'.
'last' will check only the last op of the graph while 'all' will
check all the ops of the graph.
- if ops_to_check is an op or a tuple of ops, the function will check
that all the outputs of their occurrences in the graph have a stack
trace.
- if ops_to_check is a function, it should take as input a
theano.gof.Op and return a boolean indicating if the input op should
be checked or not.
bug_print: string belonging to {'raise', 'warn', 'ignore'}
You can specify the behaviour of the function when the specified
ops_to_check are not in the graph of f_or_fgraph: it can either raise
an exception, write a warning or simply ignore it.
Returns
-------
boolean
True if the outputs of the specified ops have a stack, False otherwise.
"""
if isinstance(f_or_fgraph, theano.compile.function_module.Function):
fgraph = f_or_fgraph.maker.fgraph
elif isinstance(f_or_fgraph, theano.gof.fg.FunctionGraph):
fgraph = f_or_fgraph
else:
raise ValueError('The type of f_or_fgraph is not supported')
if isinstance(ops_to_check, string_types):
if ops_to_check == 'last':
apply_nodes_to_check = [fgraph.outputs[0].owner]
elif ops_to_check == 'all':
apply_nodes_to_check = fgraph.apply_nodes
else:
raise ValueError('The string ops_to_check is not recognised')
elif hasattr(ops_to_check, '__call__'): # if ops_to_check is a function
apply_nodes_to_check = [node for node in fgraph.apply_nodes
if ops_to_check(node)]
else: # if ops_to_check is an op or a list of ops
apply_nodes_to_check = [node for node in fgraph.apply_nodes
if isinstance(node.op, ops_to_check)]
if not apply_nodes_to_check:
msg = 'Provided ops are not in the graph'
if bug_print == 'warn':
warnings.warn(msg)
elif bug_print == 'raise':
raise Exception(msg)
elif bug_print == 'ignore':
pass
else:
raise ValueError('The string bug_print is not recognised')
for node in apply_nodes_to_check:
for output in node.outputs:
if not hasattr(output.tag, 'trace'):
return False
return True
...@@ -2,9 +2,9 @@ from __future__ import absolute_import, print_function, division ...@@ -2,9 +2,9 @@ from __future__ import absolute_import, print_function, division
import theano import theano
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.gof import Op, Apply, TopoOptimizer from theano.gof import Op, Apply, TopoOptimizer
from theano.gof.opt import copy_stack_trace
from theano import tensor from theano import tensor
import theano.sandbox.cuda as cuda import theano.sandbox.cuda as cuda
from theano.tensor.opt import copy_stack_trace
def get_diagonal_subtensor_view(x, i0, i1): def get_diagonal_subtensor_view(x, i0, i1):
......
...@@ -20,10 +20,10 @@ from six.moves import xrange ...@@ -20,10 +20,10 @@ from six.moves import xrange
import theano import theano
from theano import gof from theano import gof
from theano import scalar from theano import scalar
from theano.gof.opt import copy_stack_trace
from theano.tensor import basic as tensor, subtensor, opt, elemwise from theano.tensor import basic as tensor, subtensor, opt, elemwise
from theano.tensor.type import (values_eq_approx_remove_inf, from theano.tensor.type import (values_eq_approx_remove_inf,
values_eq_approx_remove_nan) values_eq_approx_remove_nan)
from theano.tensor.opt import copy_stack_trace
from theano.compile import optdb from theano.compile import optdb
from theano.gof import Apply from theano.gof import Apply
......
...@@ -6,6 +6,7 @@ import theano ...@@ -6,6 +6,7 @@ import theano
from theano import compile, gof from theano import compile, gof
from theano.compile import optdb from theano.compile import optdb
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.gof.opt import copy_stack_trace
from theano.tensor.nnet.corr import ( from theano.tensor.nnet.corr import (
CorrMM, CorrMM_gradInputs, CorrMM_gradWeights) CorrMM, CorrMM_gradInputs, CorrMM_gradWeights)
...@@ -18,8 +19,7 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d, ...@@ -18,8 +19,7 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
AbstractConv2d_gradWeights, AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs) AbstractConv2d_gradInputs)
from theano.tensor.nnet.abstract_conv import get_conv_output_shape from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.opt import (copy_stack_trace, from theano.tensor.opt import register_specialize_device
register_specialize_device)
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.tensor import opt from theano.tensor import opt
......
...@@ -18,7 +18,7 @@ from theano.printing import pprint ...@@ -18,7 +18,7 @@ from theano.printing import pprint
from theano.tensor import basic as tensor from theano.tensor import basic as tensor
from theano.tensor import elemwise, opt, NotScalarConstantError from theano.tensor import elemwise, opt, NotScalarConstantError
from theano.tensor.type import values_eq_approx_remove_inf from theano.tensor.type import values_eq_approx_remove_inf
from theano.tensor.opt import copy_stack_trace from theano.gof.opt import copy_stack_trace
############ ############
# #
......
...@@ -22,6 +22,7 @@ from theano import gof ...@@ -22,6 +22,7 @@ from theano import gof
from theano.compat import izip from theano.compat import izip
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant from theano.gof import Variable, Constant
from theano.gof.opt import copy_stack_trace
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.configparser import config from theano.configparser import config
...@@ -54,51 +55,6 @@ _logger = logging.getLogger('theano.tensor.opt') ...@@ -54,51 +55,6 @@ _logger = logging.getLogger('theano.tensor.opt')
# Utilities # Utilities
def copy_stack_trace(from_var, to_var):
"""
Copies the stack trace from one or more tensor variables to
one or more tensor variables.
Parameters
----------
from_var
Tensor variable or list of tensor variables to copy stack traces from.
to_var
Tensor variable or list of tensor variables to copy stack traces to.
Notes
-----
The stacktrace is assumed to be of the form of a list of lists
of tuples. Each tuple contains the filename, line number, function name
and so on. Each list of tuples contains the truples belonging to a
particular variable.
"""
# Store stack traces from from_var
tr = []
if type(from_var) is list:
# If from_var is a list, store concatenated stack traces
for v in from_var:
tr += getattr(v.tag, 'trace', [])
else:
# If from_var is not a list, it must be a single tensor variable,
# so just store that particular stack trace
tr = getattr(from_var.tag, 'trace', [])
# Copy over stack traces to to_var
if type(to_var) is list:
# Copy over stack traces from from_var to each variable in
# to_var, including the stack_trace of the to_var before
for v in to_var:
v.tag.trace = getattr(v.tag, 'trace', []) + tr
else:
# Copy over stack traces from from_var to each variable to
# to_var, including the stack_trace of the to_var before
to_var.tag.trace = getattr(to_var.tag, 'trace', []) + tr
def out2in(*local_opts, **kwargs): def out2in(*local_opts, **kwargs):
"""WRITEME """ """WRITEME """
name = (kwargs and kwargs.pop('name', None)) name = (kwargs and kwargs.pop('name', None))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论