提交 83b27034 authored 作者: James Bergstra's avatar James Bergstra

tweaks to pydotprint for paper

上级 08578830
...@@ -359,7 +359,8 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n ...@@ -359,7 +359,8 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint pp = pprint
def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'), compact=True, mode=None, format='png'): def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'),
compact=True, mode=None, format='png', with_ids=False):
""" """
print to a file in png format the graph of op of a compile theano fct. print to a file in png format the graph of op of a compile theano fct.
...@@ -390,14 +391,15 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -390,14 +391,15 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
g=pd.Dot() g=pd.Dot()
var_str={} var_str={}
all_strings = set()
def var_name(var): def var_name(var):
if var in var_str: if var in var_str:
return var_str[var] return var_str[var]
if var.name is not None: if var.name is not None:
varstr = var.name+" "+str(var.type) varstr = 'name='+var.name+" "+str(var.type)
elif isinstance(var,gof.Constant): elif isinstance(var,gof.Constant):
dstr = str(var.data) dstr = 'val='+str(var.data)
if '\n' in dstr: if '\n' in dstr:
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30: if len(dstr) > 30:
...@@ -408,12 +410,17 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -408,12 +410,17 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
else: else:
#a var id is needed as otherwise var with the same type will be merged in the graph. #a var id is needed as otherwise var with the same type will be merged in the graph.
varstr = str(var.type) varstr = str(var.type)
varstr += ' ' + str(len(var_str)) if (varstr in all_strings) or with_ids:
varstr += ' id=' + str(len(var_str))
var_str[var]=varstr var_str[var]=varstr
all_strings.add(varstr)
return varstr return varstr
topo = fct.maker.env.toposort() topo = fct.maker.env.toposort()
apply_name_cache = {}
def apply_name(node): def apply_name(node):
if node in apply_name_cache:
return apply_name_cache[node]
prof_str='' prof_str=''
if mode: if mode:
time = mode.apply_time.get((topo.index(node),node),0) time = mode.apply_time.get((topo.index(node),node),0)
...@@ -425,7 +432,12 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -425,7 +432,12 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
pf=0 pf=0
else: pf = time*100/mode.fct_call_time[fct] else: pf = time*100/mode.fct_call_time[fct]
prof_str=' (%.3fs,%.3f%%,%.3f%%)'%(time,pt,pf) prof_str=' (%.3fs,%.3f%%,%.3f%%)'%(time,pt,pf)
return str(node.op).replace(':','_')+' '+str(topo.index(node))+prof_str applystr = str(node.op).replace(':','_')
if (applystr in all_strings) or with_ids:
applystr = applystr+' id='+str(topo.index(node))+prof_str
all_strings.add(applystr)
apply_name_cache[node] = applystr
return applystr
# Update the inputs that have an update function # Update the inputs that have an update function
input_update={} input_update={}
...@@ -434,16 +446,18 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -434,16 +446,18 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
if i.update is not None: if i.update is not None:
input_update[outputs.pop()] = i input_update[outputs.pop()] = i
apply_shape='ellipse'
var_shape='box'
for node_idx,node in enumerate(topo): for node_idx,node in enumerate(topo):
astr=apply_name(node) astr=apply_name(node)
g.add_node(pd.Node(astr,shape='box')) g.add_node(pd.Node(astr,shape=apply_shape))
for id,var in enumerate(node.inputs): for id,var in enumerate(node.inputs):
varstr=var_name(var) varstr=var_name(var)
label='' label=''
if len(node.inputs)>1: if len(node.inputs)>1:
label=str(id) label=str(id)
if var.owner is None: if var.owner is None:
g.add_node(pd.Node(varstr,color='green')) g.add_node(pd.Node(varstr,color='green',shape=var_shape))
g.add_edge(pd.Edge(varstr,astr, label=label)) g.add_edge(pd.Edge(varstr,astr, label=label))
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(varstr,astr, label=label)) g.add_edge(pd.Edge(varstr,astr, label=label))
...@@ -460,10 +474,10 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -460,10 +474,10 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
label=str(id) label=str(id)
if out: if out:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
g.add_node(pd.Node(varstr,color='blue')) g.add_node(pd.Node(varstr,color='blue',shape=var_shape))
elif len(var.clients)==0: elif len(var.clients)==0:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
g.add_node(pd.Node(varstr,color='grey')) g.add_node(pd.Node(varstr,color='grey',shape=var_shape))
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
# else: # else:
...@@ -495,9 +509,9 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -495,9 +509,9 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
return var_str[var] return var_str[var]
if var.name is not None: if var.name is not None:
varstr = var.name varstr = 'name='+var.name
elif isinstance(var,gof.Constant): elif isinstance(var,gof.Constant):
dstr = str(var.data) dstr = 'val='+str(var.data)
if '\n' in dstr: if '\n' in dstr:
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30: if len(dstr) > 30:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论