提交 b3bb3006 authored 作者: Frederic Bastien's avatar Frederic Bastien

in pydotprint(), added a parameter compact that default to True.

If True, won't print the intermediate var that don't have name.
上级 a035ef7a
...@@ -347,12 +347,13 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n ...@@ -347,12 +347,13 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint pp = pprint
def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png')): def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'), compact=True):
""" """
print to a file in png format the graph of op of a compile theano fct. print to a file in png format the graph of op of a compile theano fct.
:param fct: the theano fct returned by theano.function. :param fct: the theano fct returned by theano.function.
:param outfile: the output file where to put the graph. :param outfile: the output file where to put the graph.
:param compact: if True, will remove intermediate var that don't have name.
In the graph, box are an Apply Node(the execution of an op) and ellipse are variable. In the graph, box are an Apply Node(the execution of an op) and ellipse are variable.
If variable have name they are used as the text(if multiple var have the same name, they will be merged in the graph). If variable have name they are used as the text(if multiple var have the same name, they will be merged in the graph).
...@@ -361,7 +362,10 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -361,7 +362,10 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
green ellipse are input to the graph and blue ellipse are output of the graph. green ellipse are input to the graph and blue ellipse are output of the graph.
""" """
try:
import pydot as pd import pydot as pd
except:
print "failed to import pydot. Yous must install pydot for this function to work."
g=pd.Dot() g=pd.Dot()
var_str={} var_str={}
...@@ -382,6 +386,9 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -382,6 +386,9 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
var_str[var]=varstr var_str[var]=varstr
return varstr return varstr
topo = fct.maker.env.toposort()
def apply_name(node):
return str(node.op).replace(':','_')+' '+str(topo.index(node))
# Update the inputs that have an update function # Update the inputs that have an update function
input_update={} input_update={}
...@@ -390,29 +397,37 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -390,29 +397,37 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
if i.update is not None: if i.update is not None:
input_update[outputs.pop()] = i input_update[outputs.pop()] = i
for node_idx,node in enumerate(fct.maker.env.toposort()): for node_idx,node in enumerate(topo):
astr=str(node.op).replace(':','_')+' '+str(node_idx) astr=apply_name(node)
g.add_node(pd.Node(astr,shape='box')) g.add_node(pd.Node(astr,shape='box'))
for var in node.inputs: for var in node.inputs:
varstr=var_name(var) varstr=var_name(var)
if var.owner is None: if var.owner is None:
g.add_node(pd.Node(varstr,color='green')) g.add_node(pd.Node(varstr,color='green'))
g.add_edge(pd.Edge(varstr,astr)) g.add_edge(pd.Edge(varstr,astr))
elif var.name or not compact:
g.add_edge(pd.Edge(varstr,astr))
else:
#no name, so we don't make a var ellipse
g.add_edge(pd.Edge(apply_name(var.owner),astr))
for var in node.outputs: for var in node.outputs:
varstr=var_name(var) varstr=var_name(var)
out = any([x[0]=='output' for x in var.clients])
if out:
g.add_edge(pd.Edge(astr,varstr)) g.add_edge(pd.Edge(astr,varstr))
if any([x[0]=='output' for x in var.env.clients(var)]):
g.add_node(pd.Node(varstr,color='blue')) g.add_node(pd.Node(varstr,color='blue'))
elif var.name or not compact:
g.add_edge(pd.Edge(astr,varstr))
else:
#no name, so we don't make a var ellipse
for client in var.clients:
edge = pd.Edge(astr,apply_name(client[0]))
g.add_edge(edge)
g.set_simplify(True)
g.write_png(outfile, prog='dot') g.write_png(outfile, prog='dot')
print 'The output file is available at',outfile print 'The output file is available at',outfile
#from matplotlib import pyplot
#image=pyplot.imread(outfile)
#pyplot.imshow(image)
#import pdb;pdb.set_trace()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论