提交 e13ccfb5 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2154 from caglar/scan_debugprint

WIP: Fixed #2079
...@@ -494,7 +494,8 @@ def char_from_number(number): ...@@ -494,7 +494,8 @@ def char_from_number(number):
def debugprint(r, prefix='', depth=-1, done=None, print_type=False, def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
file=sys.stdout, print_destroy_map=False, file=sys.stdout, print_destroy_map=False,
print_view_map=False, order=None, ids='CHAR', print_view_map=False, order=None, ids='CHAR',
stop_on_name=False, prefix_child=None): stop_on_name=False, prefix_child=None,
scan_ops=None):
"""Print the graph leading to `r` to given depth. """Print the graph leading to `r` to given depth.
:param r: Variable instance :param r: Variable instance
...@@ -502,10 +503,10 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -502,10 +503,10 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
:param depth: maximum recursion depth (Default -1 for unlimited). :param depth: maximum recursion depth (Default -1 for unlimited).
:param done: dict of Apply instances that have already been printed :param done: dict of Apply instances that have already been printed
and their associated printed ids and their associated printed ids
:param print_type: wether to print the Variable type after the other infos :param print_type: whether to print the Variable type after the other infos
:param file: file-like object to which to print :param file: file-like object to which to print
:param print_destroy_map: wether to print the op destroy_map after ofther info :param print_destroy_map: whether to print the op destroy_map after other info
:param print_view_map: wether to print the op view_map after ofther info :param print_view_map: whether to print the op view_map after other info
:param order: If not empty will print the index in the toposort. :param order: If not empty will print the index in the toposort.
:param ids: How do we print the identifier of the variable :param ids: How do we print the identifier of the variable
id - print the python id value id - print the python id value
...@@ -514,6 +515,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -514,6 +515,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
"" - don't print an identifier "" - don't print an identifier
:param stop_on_name: When True, if a node in the graph has a name, :param stop_on_name: When True, if a node in the graph has a name,
we don't print anything below it. we don't print anything below it.
:param scan_ops: Scan ops in the graph will be added inside this list
for later printing purposes.
""" """
if depth == 0: if depth == 0:
...@@ -525,6 +528,9 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -525,6 +528,9 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
if done is None: if done is None:
done = dict() done = dict()
if scan_ops is None:
scan_ops = []
if print_type: if print_type:
type_str = ' <%s>' % r.type type_str = ' <%s>' % r.type
else: else:
...@@ -575,37 +581,45 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -575,37 +581,45 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
o = '' o = ''
if order: if order:
o = str(order.index(r.owner)) o = str(order.index(r.owner))
already_printed = a in done # get_id_str put it in the dict already_printed = a in done # get_id_str put it in the dict
id_str = get_id_str(a) id_str = get_id_str(a)
if len(a.outputs) == 1: if len(a.outputs) == 1:
print >> file, '%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op, print >> file, '%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op,
id_str, id_str,
type_str, r_name, type_str,
r_name,
destroy_map_str,
view_map_str,
o)
else:
print >> file, '%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op,
a.outputs.index(r),
id_str, type_str,
r_name,
destroy_map_str, destroy_map_str,
view_map_str, view_map_str,
o) o)
else:
print >> file, '%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op,
a.outputs.index(r),
id_str, type_str,
r_name,
destroy_map_str,
view_map_str,
o)
if not already_printed: if not already_printed:
if (not stop_on_name or if (not stop_on_name or
not (hasattr(r, 'name') and r.name is not None)): not (hasattr(r, 'name') and r.name is not None)):
new_prefix = prefix_child + ' |' new_prefix = prefix_child + ' |'
new_prefix_child = prefix_child + ' |' new_prefix_child = prefix_child + ' |'
for idx, i in enumerate(a.inputs): for idx, i in enumerate(a.inputs):
if idx == len(a.inputs) - 1: if idx == len(a.inputs) - 1:
new_prefix_child = prefix_child + ' ' new_prefix_child = prefix_child + ' '
if hasattr(i, 'owner') and hasattr(i.owner, 'op'):
if isinstance(i.owner.op, theano.scan_module.scan_op.Scan):
scan_ops.append(i)
debugprint(i, new_prefix, depth=depth - 1, done=done, debugprint(i, new_prefix, depth=depth - 1, done=done,
print_type=print_type, file=file, order=order, print_type=print_type, file=file, order=order,
ids=ids, stop_on_name=stop_on_name, ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child) prefix_child=new_prefix_child, scan_ops=scan_ops)
else: else:
#this is an input variable #this is an input variable
id_str = get_id_str(r) id_str = get_id_str(r)
...@@ -624,7 +638,6 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -624,7 +638,6 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
:type accept_inplace: Bool :type accept_inplace: Bool
:rtype: `FunctionGraph` :rtype: `FunctionGraph`
:returns: a new FunctionGraph with a cloned graph, with debugging `Feature` instances already installed. :returns: a new FunctionGraph with a cloned graph, with debugging `Feature` instances already installed.
""" """
orig_inputs = [spec.variable for spec in input_specs] orig_inputs = [spec.variable for spec in input_specs]
updates = [spec.update for spec in input_specs if spec.update] updates = [spec.update for spec in input_specs if spec.update]
...@@ -2152,7 +2165,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2152,7 +2165,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
# Check if some input variables are unused # Check if some input variables are unused
self._check_unused_inputs(inputs, outputs, on_unused_input) self._check_unused_inputs(inputs, outputs, on_unused_input)
# Make a list of (SymbolicInput|SymblicInputKits, indices, [SymbolicInput,...]), one # Make a list of (SymbolicInput|SymblicInputKits, indices, [SymbolicInput,...]), one
# tuple for each input. (See Function.indices for more details) # tuple for each input. (See Function.indices for more details)
indices = [[input] + self.expand_in(input, _inputs) for input in inputs] indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
......
...@@ -102,10 +102,38 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -102,10 +102,38 @@ def debugprint(obj, depth=-1, print_type=False,
else: else:
raise TypeError("debugprint cannot print an object of this type", raise TypeError("debugprint cannot print an object of this type",
obj) obj)
scan_ops = []
for r in results_to_print: for r in results_to_print:
#Add the parent scan op to the list as well
if hasattr(r.owner, 'op') and isinstance(r.owner.op, theano.scan_module.scan_op.Scan):
scan_ops.append(r)
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order, ids=ids, file=_file, order=order, ids=ids,
stop_on_name=stop_on_name) scan_ops=scan_ops, stop_on_name=stop_on_name)
if len(scan_ops) > 0:
print >> file, ""
new_prefix = ' >'
new_prefix_child = ' >'
print >> file, "Inner graphs of the scan ops:"
for s in scan_ops:
print >> file, ""
debugmode.debugprint(s, depth=depth, done=done, print_type=print_type,
file=_file, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name)
for idx, i in enumerate(s.owner.op.outputs):
if hasattr(i, 'owner') and hasattr(i.owner, 'op'):
if isinstance(i.owner.op, theano.scan_module.scan_op.Scan):
scan_ops.append(i)
debugmode.debugprint(r=i, prefix=new_prefix, depth=depth, done=done,
print_type=print_type, file=file,
ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, scan_ops=scan_ops)
if file is _file: if file is _file:
return file return file
elif file == 'str': elif file == 'str':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论