提交 246a557b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

The new version of pydotprinting that has an extra option high_contrast that

allows you to color the nodes with colors. You can optionally provide a dictionary of colors to the function.
上级 ffd02ba4
......@@ -373,8 +373,20 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint
default_colorCodes = {'GpuFromHost' : 'red',
'HostFromGpu' : 'red',
'Scan' : 'yellow',
'Shape' : 'cyan',
'Cond' : 'magenta',
'Elemwise': '#FFAABB',
'Subtensor': '#FFAAFF'}
def pydotprint(fct, outfile=None,
compact=True, format='png', with_ids=False):
compact=True, format='png', with_ids=False,
high_contrast=False, cond_highlight = None, colorCodes = None):
"""
print to a file in png format the graph of op of a compile theano fct.
......@@ -382,6 +394,11 @@ def pydotprint(fct, outfile=None,
:param outfile: the output file where to put the graph.
:param compact: if True, will remove intermediate var that don't have name.
:param format: the file format of the output.
:param high_contrast: if true, the color that describes the respective
node is filled with its corresponding color, instead of coloring
the border
:param colorCodes: dictionary with names of ops as keys and colors as
values
In the graph, box are an Apply Node(the execution of an op) and ellipse are variable.
If variable have name they are used as the text(if multiple var have the same name, they will be merged in the graph).
......@@ -395,12 +412,25 @@ def pydotprint(fct, outfile=None,
red ellipses are transfer to/from the gpu.
op with those name GpuFromHost, HostFromGpu
"""
if colorCodes is None:
colorCodes = default_colorCodes
if outfile is None:
outfile = os.path.join(config.compiledir,'theano.pydotprint.' +
config.device + '.' + format)
mode = fct.maker.mode
if not isinstance(mode,ProfileMode) or not mode.fct_call.has_key(fct):
if isinstance(fct, Function):
mode = fct.maker.mode
fct_env = fct.maker.env
if not isinstance(mode,ProfileMode) or not mode.fct_call.has_key(fct):
mode=None
elif isinstance(fct, gof.Env):
mode = None
fct_env = fct
else:
raise ValueError(('pydotprint expects as input a theano.function or'
'the env of a function!'), fct)
try:
import pydot as pd
except:
......@@ -408,8 +438,36 @@ def pydotprint(fct, outfile=None,
return
g=pd.Dot()
if cond_highlight is not None:
c1 = pd.Cluster('Left')
c2 = pd.Cluster('Right')
c3 = pd.Cluster('Middle')
for node in fct_env.toposort():
if node.op.__class__.__name__=='Cond' and node.op.name == cond_highlight:
cond = node
def recursive_pass(x,ls):
if not x.owner:
return ls
else:
ls += [x.owner]
for inp in x.inputs:
ls += recursive_pass(inp, ls)
return ls
left = set(recursive_pass(cond.inputs[1],[]))
right =set(recursive_pass(cond.inputs[2],[]))
middle = left.intersecton(right)
left = left.difference(middle)
right = right.difference(middle)
middle = list(middle)
left = list(middle)
right = list(middle)
var_str={}
all_strings = set()
def var_name(var):
if var in var_str:
return var_str[var]
......@@ -434,7 +492,7 @@ def pydotprint(fct, outfile=None,
all_strings.add(varstr)
return varstr
topo = fct.maker.env.toposort()
topo = fct_env.toposort()
apply_name_cache = {}
def apply_name(node):
if node in apply_name_cache:
......@@ -460,21 +518,38 @@ def pydotprint(fct, outfile=None,
# Update the inputs that have an update function
input_update={}
outputs = list(fct.maker.env.outputs)
for i in reversed(fct.maker.expanded_inputs):
if i.update is not None:
input_update[outputs.pop()] = i
outputs = list(fct_env.outputs)
if isinstance(fct, Function):
for i in reversed(fct.maker.expanded_inputs):
if i.update is not None:
input_update[outputs.pop()] = i
apply_shape='ellipse'
var_shape='box'
for node_idx,node in enumerate(topo):
astr=apply_name(node)
if node.op.__class__.__name__ in ('GpuFromHost','HostFromGpu'):
# highlight CPU-GPU transfers to simplify optimization
g.add_node(pd.Node(astr,color='red',shape=apply_shape))
use_color = None
for opName, color in colorCodes.items():
if opName in node.op.__class__.__name__:
use_color = color
if use_color is None:
nw_node = pd.Node(astr, shape=apply_shape)
elif high_contrast:
nw_node = pd.Node(astr, style='filled', fillcolor=use_color,
shape = apply_shape)
else:
g.add_node(pd.Node(astr,shape=apply_shape))
nw_node = pd.Node(astr,color=use_color, shape = apply_shape)
g.add_node(nw_node)
if cond_highlight:
if node in middle:
c3.add_node(nw_node)
elif node in left:
c1.add_node(nw_node)
elif node in right:
c2.add_node(nw_node)
for id,var in enumerate(node.inputs):
varstr=var_name(var)
......@@ -482,7 +557,12 @@ def pydotprint(fct, outfile=None,
if len(node.inputs)>1:
label=str(id)+' '+label
if var.owner is None:
g.add_node(pd.Node(varstr,color='green',shape=var_shape))
if high_contrast:
g.add_node(pd.Node(varstr
,style = 'filled'
, fillcolor='green',shape=var_shape))
else:
g.add_node(pd.Node(varstr,color='green',shape=var_shape))
g.add_edge(pd.Edge(varstr,astr, label=label))
elif var.name or not compact:
g.add_edge(pd.Edge(varstr,astr, label=label))
......@@ -499,14 +579,28 @@ def pydotprint(fct, outfile=None,
label=str(id)+' '+label
if out:
g.add_edge(pd.Edge(astr, varstr, label=label))
g.add_node(pd.Node(varstr,color='blue',shape=var_shape))
if high_contrast:
g.add_node(pd.Node(varstr,style='filled'
,fillcolor='blue',shape=var_shape))
else:
g.add_node(pd.Node(varstr,color='blue',shape=var_shape))
elif len(var.clients)==0:
g.add_edge(pd.Edge(astr, varstr, label=label))
g.add_node(pd.Node(varstr,color='grey',shape=var_shape))
if high_contrast:
g.add_node(pd.Node(varstr,style='filled',
fillcolor='grey',shape=var_shape))
else:
g.add_node(pd.Node(varstr,color='grey',shape=var_shape))
elif var.name or not compact:
g.add_edge(pd.Edge(astr, varstr, label=label))
# else:
#don't add egde here as it is already added from the inputs.
if cond_highlight:
g.add_subgraph(c1)
g.add_subgraph(c2)
g.add_subgraph(c3)
if not outfile.endswith('.'+format):
outfile+='.'+format
g.write(outfile, prog='dot', format=format)
......@@ -516,7 +610,10 @@ def pydotprint(fct, outfile=None,
def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'), depth = -1):
def pydotprint_variables(vars,
outfile=os.path.join(config.compiledir,'theano.pydotprint.png'),
depth = -1,
high_contrast = True):
''' Identical to pydotprint just that it starts from a variable instead
of a compiled function. Could be useful ? '''
try:
......@@ -526,6 +623,7 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
return
g=pd.Dot()
my_list = {}
orphanes = []
if type(vars) not in (list,tuple):
vars = [vars]
var_str = {}
......@@ -559,12 +657,32 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
return
astr = apply_name(app) + '_' + str(len(my_list.keys()))
my_list[app] = astr
g.add_node(pd.Node(astr, shape='box'))
use_color = None
for opName, color in colorCodes.items():
if opName in app.op.__class__.__name__ :
use_color = color
if use_color is None:
g.add_node(pd.Node(astr, shape='box'))
elif high_contrast:
g.add_node(pd.Node(astr, style='filled', fillcolor=use_color,
shape = 'box'))
else:
g.add_node(pd.Nonde(astr,color=use_color, shape = 'box'))
for i,nd in enumerate(app.inputs):
if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys()))
my_list[nd] = varastr
g.add_node(pd.Node(varastr))
if nd.owner is not None:
g.add_node(pd.Node(varastr))
elif high_contrast:
g.add_node(pd.Node(varastr, style ='filled',
fillcolor='green'))
else:
g.add_node(pd.Node(varastr, color='green'))
else:
varastr = my_list[nd]
label = ''
......@@ -576,7 +694,18 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys()))
my_list[nd] = varastr
g.add_node(pd.Node(varastr))
color = None
if nd in vars:
color = 'blue'
elif nd in orphanes :
color = 'gray'
if color is None:
g.add_node(pd.Node(varastr))
elif high_contrast:
g.add_node(pd.Node(varastr, style='filled',
fillcolor=color))
else:
g.add_node(pd.Node(varastr, color = color))
else:
varastr = my_list[nd]
label = ''
......@@ -588,6 +717,12 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
plot_apply(nd.owner, d-1)
for nd in vars:
if nd.owner:
for k in nd.owner.outputs:
if k not in vars:
orphanes.append(k)
for nd in vars:
if nd.owner:
plot_apply(nd.owner, depth)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论