提交 83b27034 authored 作者: James Bergstra's avatar James Bergstra

tweaks to pydotprint for paper

上级 08578830
......@@ -359,7 +359,8 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint
def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'), compact=True, mode=None, format='png'):
def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'),
compact=True, mode=None, format='png', with_ids=False):
"""
print to a file in png format the graph of op of a compile theano fct.
......@@ -390,14 +391,15 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
g=pd.Dot()
var_str={}
all_strings = set()
def var_name(var):
if var in var_str:
return var_str[var]
if var.name is not None:
varstr = var.name+" "+str(var.type)
varstr = 'name='+var.name+" "+str(var.type)
elif isinstance(var,gof.Constant):
dstr = str(var.data)
dstr = 'val='+str(var.data)
if '\n' in dstr:
dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30:
......@@ -408,12 +410,17 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
else:
#a var id is needed as otherwise var with the same type will be merged in the graph.
varstr = str(var.type)
varstr += ' ' + str(len(var_str))
if (varstr in all_strings) or with_ids:
varstr += ' id=' + str(len(var_str))
var_str[var]=varstr
all_strings.add(varstr)
return varstr
topo = fct.maker.env.toposort()
apply_name_cache = {}
def apply_name(node):
if node in apply_name_cache:
return apply_name_cache[node]
prof_str=''
if mode:
time = mode.apply_time.get((topo.index(node),node),0)
......@@ -425,7 +432,12 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
pf=0
else: pf = time*100/mode.fct_call_time[fct]
prof_str=' (%.3fs,%.3f%%,%.3f%%)'%(time,pt,pf)
return str(node.op).replace(':','_')+' '+str(topo.index(node))+prof_str
applystr = str(node.op).replace(':','_')
if (applystr in all_strings) or with_ids:
applystr = applystr+' id='+str(topo.index(node))+prof_str
all_strings.add(applystr)
apply_name_cache[node] = applystr
return applystr
# Update the inputs that have an update function
input_update={}
......@@ -434,16 +446,18 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
if i.update is not None:
input_update[outputs.pop()] = i
apply_shape='ellipse'
var_shape='box'
for node_idx,node in enumerate(topo):
astr=apply_name(node)
g.add_node(pd.Node(astr,shape='box'))
g.add_node(pd.Node(astr,shape=apply_shape))
for id,var in enumerate(node.inputs):
varstr=var_name(var)
label=''
if len(node.inputs)>1:
label=str(id)
if var.owner is None:
g.add_node(pd.Node(varstr,color='green'))
g.add_node(pd.Node(varstr,color='green',shape=var_shape))
g.add_edge(pd.Edge(varstr,astr, label=label))
elif var.name or not compact:
g.add_edge(pd.Edge(varstr,astr, label=label))
......@@ -460,10 +474,10 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
label=str(id)
if out:
g.add_edge(pd.Edge(astr, varstr, label=label))
g.add_node(pd.Node(varstr,color='blue'))
g.add_node(pd.Node(varstr,color='blue',shape=var_shape))
elif len(var.clients)==0:
g.add_edge(pd.Edge(astr, varstr, label=label))
g.add_node(pd.Node(varstr,color='grey'))
g.add_node(pd.Node(varstr,color='grey',shape=var_shape))
elif var.name or not compact:
g.add_edge(pd.Edge(astr, varstr, label=label))
# else:
......@@ -495,9 +509,9 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
return var_str[var]
if var.name is not None:
varstr = var.name
varstr = 'name='+var.name
elif isinstance(var,gof.Constant):
dstr = str(var.data)
dstr = 'val='+str(var.data)
if '\n' in dstr:
dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论