提交 246a557b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

The new version of pydotprinting that has an extra option high_contrast that

allows you to color the nodes with colors. You can optionally provide a dictionary of colors to the function.
上级 ffd02ba4
...@@ -373,8 +373,20 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n ...@@ -373,8 +373,20 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint pp = pprint
default_colorCodes = {'GpuFromHost' : 'red',
'HostFromGpu' : 'red',
'Scan' : 'yellow',
'Shape' : 'cyan',
'Cond' : 'magenta',
'Elemwise': '#FFAABB',
'Subtensor': '#FFAAFF'}
def pydotprint(fct, outfile=None, def pydotprint(fct, outfile=None,
compact=True, format='png', with_ids=False): compact=True, format='png', with_ids=False,
high_contrast=False, cond_highlight = None, colorCodes = None):
""" """
print to a file in png format the graph of op of a compile theano fct. print to a file in png format the graph of op of a compile theano fct.
...@@ -382,6 +394,11 @@ def pydotprint(fct, outfile=None, ...@@ -382,6 +394,11 @@ def pydotprint(fct, outfile=None,
:param outfile: the output file where to put the graph. :param outfile: the output file where to put the graph.
:param compact: if True, will remove intermediate var that don't have name. :param compact: if True, will remove intermediate var that don't have name.
:param format: the file format of the output. :param format: the file format of the output.
:param high_contrast: if true, the color that describes the respective
node is filled with its corresponding color, instead of coloring
the border
:param colorCodes: dictionary with names of ops as keys and colors as
values
In the graph, box are an Apply Node(the execution of an op) and ellipse are variable. In the graph, box are an Apply Node(the execution of an op) and ellipse are variable.
If variable have name they are used as the text(if multiple var have the same name, they will be merged in the graph). If variable have name they are used as the text(if multiple var have the same name, they will be merged in the graph).
...@@ -395,12 +412,25 @@ def pydotprint(fct, outfile=None, ...@@ -395,12 +412,25 @@ def pydotprint(fct, outfile=None,
red ellipses are transfer to/from the gpu. red ellipses are transfer to/from the gpu.
op with those name GpuFromHost, HostFromGpu op with those name GpuFromHost, HostFromGpu
""" """
if colorCodes is None:
colorCodes = default_colorCodes
if outfile is None: if outfile is None:
outfile = os.path.join(config.compiledir,'theano.pydotprint.' + outfile = os.path.join(config.compiledir,'theano.pydotprint.' +
config.device + '.' + format) config.device + '.' + format)
if isinstance(fct, Function):
mode = fct.maker.mode mode = fct.maker.mode
fct_env = fct.maker.env
if not isinstance(mode,ProfileMode) or not mode.fct_call.has_key(fct): if not isinstance(mode,ProfileMode) or not mode.fct_call.has_key(fct):
mode=None
elif isinstance(fct, gof.Env):
mode = None mode = None
fct_env = fct
else:
raise ValueError(('pydotprint expects as input a theano.function or'
'the env of a function!'), fct)
try: try:
import pydot as pd import pydot as pd
except: except:
...@@ -408,8 +438,36 @@ def pydotprint(fct, outfile=None, ...@@ -408,8 +438,36 @@ def pydotprint(fct, outfile=None,
return return
g=pd.Dot() g=pd.Dot()
if cond_highlight is not None:
c1 = pd.Cluster('Left')
c2 = pd.Cluster('Right')
c3 = pd.Cluster('Middle')
for node in fct_env.toposort():
if node.op.__class__.__name__=='Cond' and node.op.name == cond_highlight:
cond = node
def recursive_pass(x,ls):
if not x.owner:
return ls
else:
ls += [x.owner]
for inp in x.inputs:
ls += recursive_pass(inp, ls)
return ls
left = set(recursive_pass(cond.inputs[1],[]))
right =set(recursive_pass(cond.inputs[2],[]))
middle = left.intersecton(right)
left = left.difference(middle)
right = right.difference(middle)
middle = list(middle)
left = list(middle)
right = list(middle)
var_str={} var_str={}
all_strings = set() all_strings = set()
def var_name(var): def var_name(var):
if var in var_str: if var in var_str:
return var_str[var] return var_str[var]
...@@ -434,7 +492,7 @@ def pydotprint(fct, outfile=None, ...@@ -434,7 +492,7 @@ def pydotprint(fct, outfile=None,
all_strings.add(varstr) all_strings.add(varstr)
return varstr return varstr
topo = fct.maker.env.toposort() topo = fct_env.toposort()
apply_name_cache = {} apply_name_cache = {}
def apply_name(node): def apply_name(node):
if node in apply_name_cache: if node in apply_name_cache:
...@@ -460,7 +518,8 @@ def pydotprint(fct, outfile=None, ...@@ -460,7 +518,8 @@ def pydotprint(fct, outfile=None,
# Update the inputs that have an update function # Update the inputs that have an update function
input_update={} input_update={}
outputs = list(fct.maker.env.outputs) outputs = list(fct_env.outputs)
if isinstance(fct, Function):
for i in reversed(fct.maker.expanded_inputs): for i in reversed(fct.maker.expanded_inputs):
if i.update is not None: if i.update is not None:
input_update[outputs.pop()] = i input_update[outputs.pop()] = i
...@@ -470,11 +529,27 @@ def pydotprint(fct, outfile=None, ...@@ -470,11 +529,27 @@ def pydotprint(fct, outfile=None,
for node_idx,node in enumerate(topo): for node_idx,node in enumerate(topo):
astr=apply_name(node) astr=apply_name(node)
if node.op.__class__.__name__ in ('GpuFromHost','HostFromGpu'): use_color = None
# highlight CPU-GPU transfers to simplify optimization for opName, color in colorCodes.items():
g.add_node(pd.Node(astr,color='red',shape=apply_shape)) if opName in node.op.__class__.__name__:
use_color = color
if use_color is None:
nw_node = pd.Node(astr, shape=apply_shape)
elif high_contrast:
nw_node = pd.Node(astr, style='filled', fillcolor=use_color,
shape = apply_shape)
else: else:
g.add_node(pd.Node(astr,shape=apply_shape)) nw_node = pd.Node(astr,color=use_color, shape = apply_shape)
g.add_node(nw_node)
if cond_highlight:
if node in middle:
c3.add_node(nw_node)
elif node in left:
c1.add_node(nw_node)
elif node in right:
c2.add_node(nw_node)
for id,var in enumerate(node.inputs): for id,var in enumerate(node.inputs):
varstr=var_name(var) varstr=var_name(var)
...@@ -482,6 +557,11 @@ def pydotprint(fct, outfile=None, ...@@ -482,6 +557,11 @@ def pydotprint(fct, outfile=None,
if len(node.inputs)>1: if len(node.inputs)>1:
label=str(id)+' '+label label=str(id)+' '+label
if var.owner is None: if var.owner is None:
if high_contrast:
g.add_node(pd.Node(varstr
,style = 'filled'
, fillcolor='green',shape=var_shape))
else:
g.add_node(pd.Node(varstr,color='green',shape=var_shape)) g.add_node(pd.Node(varstr,color='green',shape=var_shape))
g.add_edge(pd.Edge(varstr,astr, label=label)) g.add_edge(pd.Edge(varstr,astr, label=label))
elif var.name or not compact: elif var.name or not compact:
...@@ -499,14 +579,28 @@ def pydotprint(fct, outfile=None, ...@@ -499,14 +579,28 @@ def pydotprint(fct, outfile=None,
label=str(id)+' '+label label=str(id)+' '+label
if out: if out:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
if high_contrast:
g.add_node(pd.Node(varstr,style='filled'
,fillcolor='blue',shape=var_shape))
else:
g.add_node(pd.Node(varstr,color='blue',shape=var_shape)) g.add_node(pd.Node(varstr,color='blue',shape=var_shape))
elif len(var.clients)==0: elif len(var.clients)==0:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
if high_contrast:
g.add_node(pd.Node(varstr,style='filled',
fillcolor='grey',shape=var_shape))
else:
g.add_node(pd.Node(varstr,color='grey',shape=var_shape)) g.add_node(pd.Node(varstr,color='grey',shape=var_shape))
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
# else: # else:
#don't add egde here as it is already added from the inputs. #don't add egde here as it is already added from the inputs.
if cond_highlight:
g.add_subgraph(c1)
g.add_subgraph(c2)
g.add_subgraph(c3)
if not outfile.endswith('.'+format): if not outfile.endswith('.'+format):
outfile+='.'+format outfile+='.'+format
g.write(outfile, prog='dot', format=format) g.write(outfile, prog='dot', format=format)
...@@ -516,7 +610,10 @@ def pydotprint(fct, outfile=None, ...@@ -516,7 +610,10 @@ def pydotprint(fct, outfile=None,
def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'), depth = -1): def pydotprint_variables(vars,
outfile=os.path.join(config.compiledir,'theano.pydotprint.png'),
depth = -1,
high_contrast = True):
''' Identical to pydotprint just that it starts from a variable instead ''' Identical to pydotprint just that it starts from a variable instead
of a compiled function. Could be useful ? ''' of a compiled function. Could be useful ? '''
try: try:
...@@ -526,6 +623,7 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -526,6 +623,7 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
return return
g=pd.Dot() g=pd.Dot()
my_list = {} my_list = {}
orphanes = []
if type(vars) not in (list,tuple): if type(vars) not in (list,tuple):
vars = [vars] vars = [vars]
var_str = {} var_str = {}
...@@ -559,12 +657,32 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -559,12 +657,32 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
return return
astr = apply_name(app) + '_' + str(len(my_list.keys())) astr = apply_name(app) + '_' + str(len(my_list.keys()))
my_list[app] = astr my_list[app] = astr
use_color = None
for opName, color in colorCodes.items():
if opName in app.op.__class__.__name__ :
use_color = color
if use_color is None:
g.add_node(pd.Node(astr, shape='box')) g.add_node(pd.Node(astr, shape='box'))
elif high_contrast:
g.add_node(pd.Node(astr, style='filled', fillcolor=use_color,
shape = 'box'))
else:
g.add_node(pd.Nonde(astr,color=use_color, shape = 'box'))
for i,nd in enumerate(app.inputs): for i,nd in enumerate(app.inputs):
if nd not in my_list: if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys())) varastr = var_name(nd) + '_' + str(len(my_list.keys()))
my_list[nd] = varastr my_list[nd] = varastr
if nd.owner is not None:
g.add_node(pd.Node(varastr)) g.add_node(pd.Node(varastr))
elif high_contrast:
g.add_node(pd.Node(varastr, style ='filled',
fillcolor='green'))
else:
g.add_node(pd.Node(varastr, color='green'))
else: else:
varastr = my_list[nd] varastr = my_list[nd]
label = '' label = ''
...@@ -576,7 +694,18 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -576,7 +694,18 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
if nd not in my_list: if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys())) varastr = var_name(nd) + '_' + str(len(my_list.keys()))
my_list[nd] = varastr my_list[nd] = varastr
color = None
if nd in vars:
color = 'blue'
elif nd in orphanes :
color = 'gray'
if color is None:
g.add_node(pd.Node(varastr)) g.add_node(pd.Node(varastr))
elif high_contrast:
g.add_node(pd.Node(varastr, style='filled',
fillcolor=color))
else:
g.add_node(pd.Node(varastr, color = color))
else: else:
varastr = my_list[nd] varastr = my_list[nd]
label = '' label = ''
...@@ -588,6 +717,12 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -588,6 +717,12 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
plot_apply(nd.owner, d-1) plot_apply(nd.owner, d-1)
for nd in vars:
if nd.owner:
for k in nd.owner.outputs:
if k not in vars:
orphanes.append(k)
for nd in vars: for nd in vars:
if nd.owner: if nd.owner:
plot_apply(nd.owner, depth) plot_apply(nd.owner, depth)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论