提交 6485dbbc authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3213 from nouiz/pydotprint

Pydotprint, fix printing of reused output and updated shared var
......@@ -28,7 +28,7 @@ from theano import gof
from theano import config
from six.moves import StringIO, reduce
from theano.gof import Op, Apply
from theano.compile import Function, debugmode
from theano.compile import Function, debugmode, SharedVariable
from theano.compile.profilemode import ProfileMode
_logger = logging.getLogger("theano.printing")
......@@ -571,7 +571,7 @@ Print to the terminal a math-like expression.
default_colorCodes = {'GpuFromHost': 'red',
'HostFromGpu': 'red',
'Scan': 'yellow',
'Shape': 'cyan',
'Shape': 'brown',
'IfElse': 'magenta',
'Elemwise': '#FFAABB', # dark pink
'Subtensor': '#FFAAFF', # purple
......@@ -584,7 +584,6 @@ def pydotprint(fct, outfile=None,
max_label_size=70, scan_graphs=False,
var_with_name_simple=False,
print_output_file=True,
assert_nb_all_strings=-1,
return_image=False,
):
"""Print to a file the graph of a compiled theano function's ops. Supports
......@@ -616,10 +615,6 @@ def pydotprint(fct, outfile=None,
:param var_with_name_simple: If true and a variable has a name,
we will print only the variable name.
Otherwise, we concatenate the type to the var name.
:param assert_nb_all_strings: Used for tests. If non-negative, assert that
the number of unique string nodes in the dot graph is equal to
this number. This is used in tests to verify that dot won't
merge Theano nodes.
:param return_image: If True, it will create the image and return it.
Useful to display the image in ipython notebook.
......@@ -642,11 +637,20 @@ def pydotprint(fct, outfile=None,
label each edge between an input and the Apply node with the
input's index.
Green boxes are inputs variables to the graph,
blue boxes are outputs variables of the graph,
grey boxes are variables that are not outputs and are not used,
red ellipses are transfers from/to the gpu (ops with names GpuFromHost,
HostFromGpu).
Variable color code::
- Cyan boxes are SharedVariables (inputs and/or outputs of the graph),
- Green boxes are input variables to the graph,
- Blue boxes are output variables of the graph,
- Grey boxes are variables that are not outputs and are not used.
Default apply node code::
- Red ellipses are transfers from/to the GPU
- Yellow are Scan nodes
- Brown are Shape nodes
- Magenta are IfElse nodes
- Dark pink are Elemwise nodes
- Purple are Subtensor nodes
- Orange are Alloc nodes
For edges, they are black by default. If a node returns a view
of an input, we put the corresponding input edge in blue. If it
......@@ -732,11 +736,12 @@ def pydotprint(fct, outfile=None,
right = list(right)
var_str = {}
var_id = {}
all_strings = set()
def var_name(var):
if var in var_str:
return var_str[var]
return var_str[var], var_id[var]
if var.name is not None:
if var_with_name_simple:
......@@ -749,40 +754,31 @@ def pydotprint(fct, outfile=None,
dstr = dstr[:dstr.index('\n')]
varstr = '%s %s' % (dstr, str(var.type))
elif (var in input_update and
input_update[var].variable.name is not None):
input_update[var].name is not None):
if var_with_name_simple:
varstr = input_update[var].variable.name + " UPDATE"
varstr = input_update[var].variable.name
else:
varstr = (input_update[var].variable.name + " UPDATE " +
varstr = (input_update[var].variable.name +
str(var.type))
else:
# a var id is needed as otherwise var with the same type will be
# merged in the graph.
varstr = str(var.type)
if (varstr in all_strings) or with_ids:
idx = ' id=' + str(len(var_str))
if len(varstr) + len(idx) > max_label_size:
varstr = varstr[:max_label_size - 3 - len(idx)] + idx + '...'
else:
varstr = varstr + idx
elif len(varstr) > max_label_size:
if len(varstr) > max_label_size:
varstr = varstr[:max_label_size - 3] + '...'
idx = 1
while varstr in all_strings:
idx += 1
suffix = ' id=' + str(idx)
varstr = (varstr[:max_label_size - 3 - len(suffix)] +
'...' +
suffix)
var_str[var] = varstr
var_id[var] = str(id(var))
all_strings.add(varstr)
return varstr
return varstr, var_id[var]
apply_name_cache = {}
apply_name_id = {}
def apply_name(node):
if node in apply_name_cache:
return apply_name_cache[node]
return apply_name_cache[node], apply_name_id[node]
prof_str = ''
if mode:
time = mode.profile_stats[fct].apply_time.get(node, 0)
......@@ -825,22 +821,29 @@ def pydotprint(fct, outfile=None,
all_strings.add(applystr)
apply_name_cache[node] = applystr
return applystr
apply_name_id[node] = str(id(node))
return applystr, apply_name_id[node]
# Update the inputs that have an update function
input_update = {}
reverse_input_update = {}
# Here outputs can be the original list, as we should not change
# it, we must copy it.
outputs = list(outputs)
if isinstance(fct, Function):
for i in reversed(fct.maker.expanded_inputs):
for i, fg_ii in reversed(zip(fct.maker.expanded_inputs,
fct.maker.fgraph.inputs)):
if i.update is not None:
input_update[outputs.pop()] = i
k = outputs.pop()
# Use fgraph.inputs, as it isn't the same as maker.inputs
input_update[k] = fg_ii
reverse_input_update[fg_ii] = k
apply_shape = 'ellipse'
var_shape = 'box'
for node_idx, node in enumerate(topo):
astr = apply_name(node)
astr, aid = apply_name(node)
use_color = None
for opName, color in iteritems(colorCodes):
......@@ -848,12 +851,14 @@ def pydotprint(fct, outfile=None,
use_color = color
if use_color is None:
nw_node = pd.Node(astr, shape=apply_shape)
nw_node = pd.Node(aid, label=astr, shape=apply_shape)
elif high_contrast:
nw_node = pd.Node(astr, style='filled', fillcolor=use_color,
nw_node = pd.Node(aid, label=astr, style='filled',
fillcolor=use_color,
shape=apply_shape)
else:
nw_node = pd.Node(astr, color=use_color, shape=apply_shape)
nw_node = pd.Node(aid, label=astr,
color=use_color, shape=apply_shape)
g.add_node(nw_node)
if cond_highlight:
if node in middle:
......@@ -863,63 +868,100 @@ def pydotprint(fct, outfile=None,
elif node in right:
c2.add_node(nw_node)
for id, var in enumerate(node.inputs):
varstr = var_name(var)
label = str(var.type)
for idx, var in enumerate(node.inputs):
varstr, varid = var_name(var)
label = ""
if len(node.inputs) > 1:
label = str(id) + ' ' + label
if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...'
label = str(idx)
param = {}
if hasattr(node.op, 'view_map') and id in reduce(
if label:
param['label'] = label
if hasattr(node.op, 'view_map') and idx in reduce(
list.__add__, node.op.view_map.values(), []):
param['color'] = 'blue'
elif hasattr(node.op, 'destroy_map') and id in reduce(
elif hasattr(node.op, 'destroy_map') and idx in reduce(
list.__add__, node.op.destroy_map.values(), []):
param['color'] = 'red'
if var.owner is None:
color = 'green'
if isinstance(var, SharedVariable):
# Inputs are green, outputs are blue.
# Mixing blue and green gives cyan! (input and output var)
color = "cyan"
if high_contrast:
g.add_node(pd.Node(varstr,
g.add_node(pd.Node(varid,
style='filled',
fillcolor='green',
fillcolor=color,
label=varstr,
shape=var_shape))
else:
g.add_node(pd.Node(varstr, color='green', shape=var_shape))
g.add_edge(pd.Edge(varstr, astr, label=label, **param))
elif var.name or not compact:
g.add_edge(pd.Edge(varstr, astr, label=label, **param))
g.add_node(pd.Node(varid,
color=color,
label=varstr,
shape=var_shape))
g.add_edge(pd.Edge(varid, aid, **param))
elif var.name or not compact or var in outputs:
g.add_edge(pd.Edge(varid, aid, **param))
else:
# no name, so we don't make a var ellipse
g.add_edge(pd.Edge(apply_name(var.owner), astr,
label=label, **param))
for id, var in enumerate(node.outputs):
varstr = var_name(var)
if label:
label += " "
label += str(var.type)
if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...'
param['label'] = label
g.add_edge(pd.Edge(apply_name(var.owner)[1], aid, **param))
for idx, var in enumerate(node.outputs):
varstr, varid = var_name(var)
out = var in outputs
label = str(var.type)
label = ""
if len(node.outputs) > 1:
label = str(id) + ' ' + label
label = str(idx)
if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...'
if out:
g.add_edge(pd.Edge(astr, varstr, label=label))
param = {}
if label:
param['label'] = label
if out or var in input_update:
g.add_edge(pd.Edge(aid, varid, **param))
if high_contrast:
g.add_node(pd.Node(varstr, style='filled',
g.add_node(pd.Node(varid, style='filled',
label=varstr,
fillcolor='blue', shape=var_shape))
else:
g.add_node(pd.Node(varstr, color='blue', shape=var_shape))
g.add_node(pd.Node(varid, color='blue',
label=varstr,
shape=var_shape))
elif len(var.clients) == 0:
g.add_edge(pd.Edge(astr, varstr, label=label))
g.add_edge(pd.Edge(aid, varid, **param))
# grey means that the output var isn't used
if high_contrast:
g.add_node(pd.Node(varstr, style='filled',
g.add_node(pd.Node(varid, style='filled',
label=varstr,
fillcolor='grey', shape=var_shape))
else:
g.add_node(pd.Node(varstr, color='grey', shape=var_shape))
g.add_node(pd.Node(varid, label=varstr,
color='grey', shape=var_shape))
elif var.name or not compact:
g.add_edge(pd.Edge(astr, varstr, label=label))
if not(not compact):
if label:
label += " "
label += str(var.type)
if len(label) > max_label_size:
label = label[:max_label_size - 3] + '...'
param['label'] = label
g.add_edge(pd.Edge(aid, varid, **param))
g.add_node(pd.Node(varid, shape=var_shape, label=varstr))
# else:
# don't add an edge here as it is already added from the inputs.
# The var that represent updates, must be linked to the input var.
for sha, up in input_update.items():
_, shaid = var_name(sha)
_, upid = var_name(up)
g.add_edge(pd.Edge(shaid, upid, label="UPDATE", color="blue"))
if cond_highlight:
g.add_subgraph(c1)
g.add_subgraph(c2)
......@@ -928,9 +970,6 @@ def pydotprint(fct, outfile=None,
if not outfile.endswith('.' + format):
outfile += '.' + format
if assert_nb_all_strings != -1:
assert len(all_strings) == assert_nb_all_strings, len(all_strings)
if scan_graphs:
scan_ops = [(idx, x) for idx, x in enumerate(topo)
if isinstance(x.op, theano.scan_module.scan_op.Scan)]
......
......@@ -107,12 +107,10 @@ def test_pydotprint_long_name():
f([1, 2, 3, 4])
theano.printing.pydotprint(f, max_label_size=5,
print_output_file=False,
assert_nb_all_strings=6)
print_output_file=False)
theano.printing.pydotprint([x * 2, x + x],
max_label_size=5,
print_output_file=False,
assert_nb_all_strings=8)
print_output_file=False)
def test_pydotprint_profile():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论