提交 f5788c09 authored 作者: Frederic Bastien's avatar Frederic Bastien

added fct theano.printing.pydotprint() that write to a png file a compiled theano fct.

上级 8c95b5bf
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
""" """
import gof import gof
from copy import copy from copy import copy
import sys import sys,os
from theano import config
from gof import Op, Apply from gof import Op, Apply
from theano.gof.python25 import any
class Print(Op): class Print(Op):
"""This identity-like Op has the side effect of printing a message followed by its inputs """This identity-like Op has the side effect of printing a message followed by its inputs
...@@ -299,3 +300,53 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n ...@@ -299,3 +300,53 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint pp = pprint
def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png')):
"""
print to a file in png format the graph of op of a compile theano fct.
:param fct: the theano fct returned by theano.function.
:param outfile: the output file where to put the graph.
In the graph, box are an Apply Node(the execution of an op) and elipse are variable.
If variable have name they are used as the text(if multiple var have the same name, they will be merged in the graph). Otherwise, if a constant, we print the value and finaly we print the type + an uniq number to don't have multiple var merged.
We print the op of the apply in the Apply box with a number that represent the toposort order of application of those Apply.
"""
import pydot as pd
g=pd.Dot()
def var_name(var):
if var.name is not None:
varstr = var.name
elif isinstance(var,gof.Constant):
varstr = str(var.data)
else:
#a var id is needed as otherwise var with the same type will be merged in the graph.
varstr = str(var.type)+' '+str(id(var))
return varstr
for node_idx,node in enumerate(fct.maker.env.toposort()):
astr=str(node.op).replace(':','_')+' '+str(node_idx)
g.add_node(pd.Node(astr,shape='box'))
for var in node.inputs:
varstr=var_name(var)
if var.owner is None:
g.add_node(pd.Node(varstr,color='green'))
g.add_edge(pd.Edge(varstr,astr))
for var in node.outputs:
varstr=var_name(var)
g.add_edge(pd.Edge(astr,varstr))
if any([x[0]=='output' for x in var.env.clients(var)]):
g.add_node(pd.Node(varstr,color='blue'))
g.write_png(outfile, prog='dot')
print 'The output file is available at',outfile
#from matplotlib import pyplot
#image=pyplot.imread(outfile)
#pyplot.imshow(image)
#import pdb;pdb.set_trace()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论