提交 b3bb3006 authored 作者: Frederic Bastien's avatar Frederic Bastien

in pydotprint(), added a parameter compact that default to True.

If True, won't print the intermediate var that don't have name.
上级 a035ef7a
......@@ -347,13 +347,14 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint
def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png')):
def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'), compact=True):
"""
print to a file in png format the graph of op of a compile theano fct.
:param fct: the theano fct returned by theano.function.
:param outfile: the output file where to put the graph.
:param compact: if True, will remove intermediate var that don't have name.
In the graph, box are an Apply Node(the execution of an op) and ellipse are variable.
If variable have name they are used as the text(if multiple var have the same name, they will be merged in the graph).
Otherwise, if the variable is constant, we print the value and finaly we print the type + an uniq number to don't have multiple var merged.
......@@ -361,8 +362,11 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
green ellipse are input to the graph and blue ellipse are output of the graph.
"""
import pydot as pd
try:
import pydot as pd
except:
print "failed to import pydot. Yous must install pydot for this function to work."
g=pd.Dot()
var_str={}
def var_name(var):
......@@ -382,6 +386,9 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
var_str[var]=varstr
return varstr
topo = fct.maker.env.toposort()
def apply_name(node):
return str(node.op).replace(':','_')+' '+str(topo.index(node))
# Update the inputs that have an update function
input_update={}
......@@ -390,29 +397,37 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
if i.update is not None:
input_update[outputs.pop()] = i
for node_idx,node in enumerate(fct.maker.env.toposort()):
astr=str(node.op).replace(':','_')+' '+str(node_idx)
for node_idx,node in enumerate(topo):
astr=apply_name(node)
g.add_node(pd.Node(astr,shape='box'))
for var in node.inputs:
varstr=var_name(var)
if var.owner is None:
g.add_node(pd.Node(varstr,color='green'))
g.add_edge(pd.Edge(varstr,astr))
g.add_edge(pd.Edge(varstr,astr))
elif var.name or not compact:
g.add_edge(pd.Edge(varstr,astr))
else:
#no name, so we don't make a var ellipse
g.add_edge(pd.Edge(apply_name(var.owner),astr))
for var in node.outputs:
varstr=var_name(var)
g.add_edge(pd.Edge(astr,varstr))
if any([x[0]=='output' for x in var.env.clients(var)]):
g.add_node(pd.Node(varstr,color='blue'))
out = any([x[0]=='output' for x in var.clients])
if out:
g.add_edge(pd.Edge(astr,varstr))
g.add_node(pd.Node(varstr,color='blue'))
elif var.name or not compact:
g.add_edge(pd.Edge(astr,varstr))
else:
#no name, so we don't make a var ellipse
for client in var.clients:
edge = pd.Edge(astr,apply_name(client[0]))
g.add_edge(edge)
g.set_simplify(True)
g.write_png(outfile, prog='dot')
print 'The output file is available at',outfile
#from matplotlib import pyplot
#image=pyplot.imread(outfile)
#pyplot.imshow(image)
#import pdb;pdb.set_trace()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论