提交 6485dbbc authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3213 from nouiz/pydotprint

Pydotprint, fix printing of reused output and updated shared var
...@@ -28,7 +28,7 @@ from theano import gof ...@@ -28,7 +28,7 @@ from theano import gof
from theano import config from theano import config
from six.moves import StringIO, reduce from six.moves import StringIO, reduce
from theano.gof import Op, Apply from theano.gof import Op, Apply
from theano.compile import Function, debugmode from theano.compile import Function, debugmode, SharedVariable
from theano.compile.profilemode import ProfileMode from theano.compile.profilemode import ProfileMode
_logger = logging.getLogger("theano.printing") _logger = logging.getLogger("theano.printing")
...@@ -571,7 +571,7 @@ Print to the terminal a math-like expression. ...@@ -571,7 +571,7 @@ Print to the terminal a math-like expression.
default_colorCodes = {'GpuFromHost': 'red', default_colorCodes = {'GpuFromHost': 'red',
'HostFromGpu': 'red', 'HostFromGpu': 'red',
'Scan': 'yellow', 'Scan': 'yellow',
'Shape': 'cyan', 'Shape': 'brown',
'IfElse': 'magenta', 'IfElse': 'magenta',
'Elemwise': '#FFAABB', # dark pink 'Elemwise': '#FFAABB', # dark pink
'Subtensor': '#FFAAFF', # purple 'Subtensor': '#FFAAFF', # purple
...@@ -584,7 +584,6 @@ def pydotprint(fct, outfile=None, ...@@ -584,7 +584,6 @@ def pydotprint(fct, outfile=None,
max_label_size=70, scan_graphs=False, max_label_size=70, scan_graphs=False,
var_with_name_simple=False, var_with_name_simple=False,
print_output_file=True, print_output_file=True,
assert_nb_all_strings=-1,
return_image=False, return_image=False,
): ):
"""Print to a file the graph of a compiled theano function's ops. Supports """Print to a file the graph of a compiled theano function's ops. Supports
...@@ -616,10 +615,6 @@ def pydotprint(fct, outfile=None, ...@@ -616,10 +615,6 @@ def pydotprint(fct, outfile=None,
:param var_with_name_simple: If true and a variable has a name, :param var_with_name_simple: If true and a variable has a name,
we will print only the variable name. we will print only the variable name.
Otherwise, we concatenate the type to the var name. Otherwise, we concatenate the type to the var name.
:param assert_nb_all_strings: Used for tests. If non-negative, assert that
the number of unique string nodes in the dot graph is equal to
this number. This is used in tests to verify that dot won't
merge Theano nodes.
:param return_image: If True, it will create the image and return it. :param return_image: If True, it will create the image and return it.
Useful to display the image in ipython notebook. Useful to display the image in ipython notebook.
...@@ -642,11 +637,20 @@ def pydotprint(fct, outfile=None, ...@@ -642,11 +637,20 @@ def pydotprint(fct, outfile=None,
label each edge between an input and the Apply node with the label each edge between an input and the Apply node with the
input's index. input's index.
Green boxes are inputs variables to the graph, Variable color code::
blue boxes are outputs variables of the graph, - Cyan boxes are SharedVariable (inputs and/or outputs) of the graph,
grey boxes are variables that are not outputs and are not used, - Green boxes are inputs variables to the graph,
red ellipses are transfers from/to the gpu (ops with names GpuFromHost, - Blue boxes are outputs variables of the graph,
HostFromGpu). - Grey boxes are variables that are not outputs and are not used,
Default apply node code::
- Red ellipses are transfers from/to the gpu
- Yellow are scan node
- Brown are shape node
- Magenta are IfElse node
- Dark pink are elemwise node
- Purple are subtensor
- Orange are alloc node
For edges, they are black by default. If a node returns a view For edges, they are black by default. If a node returns a view
of an input, we put the corresponding input edge in blue. If it of an input, we put the corresponding input edge in blue. If it
...@@ -732,11 +736,12 @@ def pydotprint(fct, outfile=None, ...@@ -732,11 +736,12 @@ def pydotprint(fct, outfile=None,
right = list(right) right = list(right)
var_str = {} var_str = {}
var_id = {}
all_strings = set() all_strings = set()
def var_name(var): def var_name(var):
if var in var_str: if var in var_str:
return var_str[var] return var_str[var], var_id[var]
if var.name is not None: if var.name is not None:
if var_with_name_simple: if var_with_name_simple:
...@@ -749,40 +754,31 @@ def pydotprint(fct, outfile=None, ...@@ -749,40 +754,31 @@ def pydotprint(fct, outfile=None,
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
varstr = '%s %s' % (dstr, str(var.type)) varstr = '%s %s' % (dstr, str(var.type))
elif (var in input_update and elif (var in input_update and
input_update[var].variable.name is not None): input_update[var].name is not None):
if var_with_name_simple: if var_with_name_simple:
varstr = input_update[var].variable.name + " UPDATE" varstr = input_update[var].variable.name
else: else:
varstr = (input_update[var].variable.name + " UPDATE " + varstr = (input_update[var].variable.name +
str(var.type)) str(var.type))
else: else:
# a var id is needed as otherwise var with the same type will be # a var id is needed as otherwise var with the same type will be
# merged in the graph. # merged in the graph.
varstr = str(var.type) varstr = str(var.type)
if (varstr in all_strings) or with_ids: if len(varstr) > max_label_size:
idx = ' id=' + str(len(var_str))
if len(varstr) + len(idx) > max_label_size:
varstr = varstr[:max_label_size - 3 - len(idx)] + idx + '...'
else:
varstr = varstr + idx
elif len(varstr) > max_label_size:
varstr = varstr[:max_label_size - 3] + '...' varstr = varstr[:max_label_size - 3] + '...'
idx = 1
while varstr in all_strings:
idx += 1
suffix = ' id=' + str(idx)
varstr = (varstr[:max_label_size - 3 - len(suffix)] +
'...' +
suffix)
var_str[var] = varstr var_str[var] = varstr
var_id[var] = str(id(var))
all_strings.add(varstr) all_strings.add(varstr)
return varstr return varstr, var_id[var]
apply_name_cache = {} apply_name_cache = {}
apply_name_id = {}
def apply_name(node): def apply_name(node):
if node in apply_name_cache: if node in apply_name_cache:
return apply_name_cache[node] return apply_name_cache[node], apply_name_id[node]
prof_str = '' prof_str = ''
if mode: if mode:
time = mode.profile_stats[fct].apply_time.get(node, 0) time = mode.profile_stats[fct].apply_time.get(node, 0)
...@@ -825,22 +821,29 @@ def pydotprint(fct, outfile=None, ...@@ -825,22 +821,29 @@ def pydotprint(fct, outfile=None,
all_strings.add(applystr) all_strings.add(applystr)
apply_name_cache[node] = applystr apply_name_cache[node] = applystr
return applystr apply_name_id[node] = str(id(node))
return applystr, apply_name_id[node]
# Update the inputs that have an update function # Update the inputs that have an update function
input_update = {} input_update = {}
reverse_input_update = {}
# Here outputs can be the original list, as we should not change # Here outputs can be the original list, as we should not change
# it, we must copy it. # it, we must copy it.
outputs = list(outputs) outputs = list(outputs)
if isinstance(fct, Function): if isinstance(fct, Function):
for i in reversed(fct.maker.expanded_inputs): for i, fg_ii in reversed(zip(fct.maker.expanded_inputs,
fct.maker.fgraph.inputs)):
if i.update is not None: if i.update is not None:
input_update[outputs.pop()] = i k = outputs.pop()
# Use the fgraph.inputs as it isn't the same as maker.inputs
input_update[k] = fg_ii
reverse_input_update[fg_ii] = k
apply_shape = 'ellipse' apply_shape = 'ellipse'
var_shape = 'box' var_shape = 'box'
for node_idx, node in enumerate(topo): for node_idx, node in enumerate(topo):
astr = apply_name(node) astr, aid = apply_name(node)
use_color = None use_color = None
for opName, color in iteritems(colorCodes): for opName, color in iteritems(colorCodes):
...@@ -848,12 +851,14 @@ def pydotprint(fct, outfile=None, ...@@ -848,12 +851,14 @@ def pydotprint(fct, outfile=None,
use_color = color use_color = color
if use_color is None: if use_color is None:
nw_node = pd.Node(astr, shape=apply_shape) nw_node = pd.Node(aid, label=astr, shape=apply_shape)
elif high_contrast: elif high_contrast:
nw_node = pd.Node(astr, style='filled', fillcolor=use_color, nw_node = pd.Node(aid, label=astr, style='filled',
fillcolor=use_color,
shape=apply_shape) shape=apply_shape)
else: else:
nw_node = pd.Node(astr, color=use_color, shape=apply_shape) nw_node = pd.Node(aid, label=astr,
color=use_color, shape=apply_shape)
g.add_node(nw_node) g.add_node(nw_node)
if cond_highlight: if cond_highlight:
if node in middle: if node in middle:
...@@ -863,63 +868,100 @@ def pydotprint(fct, outfile=None, ...@@ -863,63 +868,100 @@ def pydotprint(fct, outfile=None,
elif node in right: elif node in right:
c2.add_node(nw_node) c2.add_node(nw_node)
for id, var in enumerate(node.inputs): for idx, var in enumerate(node.inputs):
varstr = var_name(var) varstr, varid = var_name(var)
label = str(var.type) label = ""
if len(node.inputs) > 1: if len(node.inputs) > 1:
label = str(id) + ' ' + label label = str(idx)
if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...'
param = {} param = {}
if hasattr(node.op, 'view_map') and id in reduce( if label:
param['label'] = label
if hasattr(node.op, 'view_map') and idx in reduce(
list.__add__, node.op.view_map.values(), []): list.__add__, node.op.view_map.values(), []):
param['color'] = 'blue' param['color'] = 'blue'
elif hasattr(node.op, 'destroy_map') and id in reduce( elif hasattr(node.op, 'destroy_map') and idx in reduce(
list.__add__, node.op.destroy_map.values(), []): list.__add__, node.op.destroy_map.values(), []):
param['color'] = 'red' param['color'] = 'red'
if var.owner is None: if var.owner is None:
color = 'green'
if isinstance(var, SharedVariable):
# Inputs are green, outputs are blue
# Mixing blue and green gives cyan! (input and output var)
color = "cyan"
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, g.add_node(pd.Node(varid,
style='filled', style='filled',
fillcolor='green', fillcolor=color,
label=varstr,
shape=var_shape)) shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color='green', shape=var_shape)) g.add_node(pd.Node(varid,
g.add_edge(pd.Edge(varstr, astr, label=label, **param)) color=color,
elif var.name or not compact: label=varstr,
g.add_edge(pd.Edge(varstr, astr, label=label, **param)) shape=var_shape))
g.add_edge(pd.Edge(varid, aid, **param))
elif var.name or not compact or var in outputs:
g.add_edge(pd.Edge(varid, aid, **param))
else: else:
# no name, so we don't make a var ellipse # no name, so we don't make a var ellipse
g.add_edge(pd.Edge(apply_name(var.owner), astr, if label:
label=label, **param)) label += " "
label += str(var.type)
if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...'
param['label'] = label
g.add_edge(pd.Edge(apply_name(var.owner)[1], aid, **param))
for id, var in enumerate(node.outputs): for idx, var in enumerate(node.outputs):
varstr = var_name(var) varstr, varid = var_name(var)
out = var in outputs out = var in outputs
label = str(var.type) label = ""
if len(node.outputs) > 1: if len(node.outputs) > 1:
label = str(id) + ' ' + label label = str(idx)
if len(label) > max_label_size: if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...' label = label[:max_label_size - 3] + '...'
if out: param = {}
g.add_edge(pd.Edge(astr, varstr, label=label)) if label:
param['label'] = label
if out or var in input_update:
g.add_edge(pd.Edge(aid, varid, **param))
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, style='filled', g.add_node(pd.Node(varid, style='filled',
label=varstr,
fillcolor='blue', shape=var_shape)) fillcolor='blue', shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color='blue', shape=var_shape)) g.add_node(pd.Node(varid, color='blue',
label=varstr,
shape=var_shape))
elif len(var.clients) == 0: elif len(var.clients) == 0:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(aid, varid, **param))
# grey means that the output var isn't used
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, style='filled', g.add_node(pd.Node(varid, style='filled',
label=varstr,
fillcolor='grey', shape=var_shape)) fillcolor='grey', shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color='grey', shape=var_shape)) g.add_node(pd.Node(varid, label=varstr,
color='grey', shape=var_shape))
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(astr, varstr, label=label)) if not(not compact):
if label:
label += " "
label += str(var.type)
if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...'
param['label'] = label
g.add_edge(pd.Edge(aid, varid, **param))
g.add_node(pd.Node(varid, shape=var_shape, label=varstr))
# else: # else:
# don't add edge here as it is already added from the inputs. # don't add edge here as it is already added from the inputs.
# The vars that represent updates must be linked to the input var.
for sha, up in input_update.items():
_, shaid = var_name(sha)
_, upid = var_name(up)
g.add_edge(pd.Edge(shaid, upid, label="UPDATE", color="blue"))
if cond_highlight: if cond_highlight:
g.add_subgraph(c1) g.add_subgraph(c1)
g.add_subgraph(c2) g.add_subgraph(c2)
...@@ -928,9 +970,6 @@ def pydotprint(fct, outfile=None, ...@@ -928,9 +970,6 @@ def pydotprint(fct, outfile=None,
if not outfile.endswith('.' + format): if not outfile.endswith('.' + format):
outfile += '.' + format outfile += '.' + format
if assert_nb_all_strings != -1:
assert len(all_strings) == assert_nb_all_strings, len(all_strings)
if scan_graphs: if scan_graphs:
scan_ops = [(idx, x) for idx, x in enumerate(topo) scan_ops = [(idx, x) for idx, x in enumerate(topo)
if isinstance(x.op, theano.scan_module.scan_op.Scan)] if isinstance(x.op, theano.scan_module.scan_op.Scan)]
......
...@@ -107,12 +107,10 @@ def test_pydotprint_long_name(): ...@@ -107,12 +107,10 @@ def test_pydotprint_long_name():
f([1, 2, 3, 4]) f([1, 2, 3, 4])
theano.printing.pydotprint(f, max_label_size=5, theano.printing.pydotprint(f, max_label_size=5,
print_output_file=False, print_output_file=False)
assert_nb_all_strings=6)
theano.printing.pydotprint([x * 2, x + x], theano.printing.pydotprint([x * 2, x + x],
max_label_size=5, max_label_size=5,
print_output_file=False, print_output_file=False)
assert_nb_all_strings=8)
def test_pydotprint_profile(): def test_pydotprint_profile():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论