提交 6b620e2f authored 作者: Amjad Almahairi's avatar Amjad Almahairi

improving debugprint for scan ops

上级 d9fc9d73
...@@ -521,7 +521,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -521,7 +521,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
file=sys.stdout, print_destroy_map=False, file=sys.stdout, print_destroy_map=False,
print_view_map=False, order=None, ids='CHAR', print_view_map=False, order=None, ids='CHAR',
stop_on_name=False, prefix_child=None, stop_on_name=False, prefix_child=None,
scan_ops=None, profile=None): scan_ops=None, profile=None,
scan_inner_to_outer_inputs=None):
"""Print the graph leading to `r` to given depth. """Print the graph leading to `r` to given depth.
:param r: Variable instance :param r: Variable instance
...@@ -544,6 +545,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -544,6 +545,8 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
we don't print anything below it. we don't print anything below it.
:param scan_ops: Scan ops in the graph will be added inside this list :param scan_ops: Scan ops in the graph will be added inside this list
for later printing purposes. for later printing purposes.
:param scan_inner_to_outer_inputs: a dictionary mapping a scan ops inner function
inputs to the scan op inputs (outer inputs) for printing purposes.
""" """
if depth == 0: if depth == 0:
...@@ -578,6 +581,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -578,6 +581,7 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
elif ids == "": elif ids == "":
id_str = "" id_str = ""
done[obj] = id_str done[obj] = id_str
return id_str return id_str
if hasattr(r.owner, 'op'): if hasattr(r.owner, 'op'):
...@@ -681,13 +685,25 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False, ...@@ -681,13 +685,25 @@ def debugprint(r, prefix='', depth=-1, done=None, print_type=False,
debugprint(i, new_prefix, depth=depth - 1, done=done, debugprint(i, new_prefix, depth=depth - 1, done=done,
print_type=print_type, file=file, order=order, print_type=print_type, file=file, order=order,
ids=ids, stop_on_name=stop_on_name, ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, prefix_child=new_prefix_child, scan_ops=scan_ops,
scan_ops=scan_ops, profile=profile) profile=profile,
scan_inner_to_outer_inputs=scan_inner_to_outer_inputs)
else: else:
# this is an input variable if scan_inner_to_outer_inputs is not None and\
id_str = get_id_str(r) r in scan_inner_to_outer_inputs:
print('%s%s %s%s' % (prefix, r, id_str, type_str), file=file)
id_str = get_id_str(r)
outer_r = scan_inner_to_outer_inputs[r]
if hasattr(outer_r.owner, 'op'):
outer_id_str = get_id_str(outer_r.owner)
else:
outer_id_str = get_id_str(outer_r)
print('%s%s %s%s -> %s' % (prefix, r, id_str, type_str, outer_id_str), file=file)
else:
# this is an input variable
id_str = get_id_str(r)
print('%s%s %s%s' % (prefix, r, id_str, type_str), file=file)
return file return file
......
...@@ -149,6 +149,7 @@ N.B.: ...@@ -149,6 +149,7 @@ N.B.:
file=_file, order=order, ids=ids, file=_file, order=order, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name, scan_ops=scan_ops, stop_on_name=stop_on_name,
profile=p) profile=p)
if len(scan_ops) > 0: if len(scan_ops) > 0:
print("", file=_file) print("", file=_file)
new_prefix = ' >' new_prefix = ' >'
...@@ -156,17 +157,31 @@ N.B.: ...@@ -156,17 +157,31 @@ N.B.:
print("Inner graphs of the scan ops:", file=_file) print("Inner graphs of the scan ops:", file=_file)
for s in scan_ops: for s in scan_ops:
# prepare a dict which maps the scan op's inner inputs to its outer inputs.
if hasattr(s.owner.op, 'fn'):
# If the op was compiled, print the optimized version.
inner_inputs = s.owner.op.fn.maker.fgraph.inputs
else:
inner_inputs = s.owner.op.inputs
outer_inputs = s.owner.inputs
inner_to_outer_inputs = dict([(inner_inputs[i],outer_inputs[o])
for i,o in enumerate(
s.owner.op.get_outer_iidx_from_inner_iidx_seq())])
#import pdb; pdb.set_trace()
print("", file=_file) print("", file=_file)
debugmode.debugprint(s, depth=depth, done=done, debugmode.debugprint(s, depth=depth, done=done,
print_type=print_type, print_type=print_type,
file=_file, ids=ids, file=_file, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name) scan_ops=scan_ops, stop_on_name=stop_on_name,
scan_inner_to_outer_inputs=inner_to_outer_inputs)
if hasattr(s.owner.op, 'fn'): if hasattr(s.owner.op, 'fn'):
# If the op was compiled, print the optimized version. # If the op was compiled, print the optimized version.
outputs = s.owner.op.fn.maker.fgraph.outputs outputs = s.owner.op.fn.maker.fgraph.outputs
else: else:
outputs = s.owner.op.outputs outputs = s.owner.op.outputs
for idx, i in enumerate(outputs): for idx, i in enumerate(outputs):
if hasattr(i, 'owner') and hasattr(i.owner, 'op'): if hasattr(i, 'owner') and hasattr(i.owner, 'op'):
if isinstance(i.owner.op, theano.scan_module.scan_op.Scan): if isinstance(i.owner.op, theano.scan_module.scan_op.Scan):
scan_ops.append(i) scan_ops.append(i)
...@@ -176,7 +191,8 @@ N.B.: ...@@ -176,7 +191,8 @@ N.B.:
print_type=print_type, file=_file, print_type=print_type, file=_file,
ids=ids, stop_on_name=stop_on_name, ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, prefix_child=new_prefix_child,
scan_ops=scan_ops) scan_ops=scan_ops,
scan_inner_to_outer_inputs=inner_to_outer_inputs)
if file is _file: if file is _file:
return file return file
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论