提交 bd59e402 authored 作者: Frederic Bastien's avatar Frederic Bastien

in pydotprint, don't merge var with the same name. shared variable in update…

in pydotprint, don't merge var with the same name. shared variable in update have distinct ellipse for their input and output in the graph. Both will have the same name with a different suffix.
上级 c311436c
...@@ -321,21 +321,32 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -321,21 +321,32 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
import pydot as pd import pydot as pd
g=pd.Dot() g=pd.Dot()
var_id={} var_str={}
def var_name(var): def var_name(var):
if var in var_str:
return var_str[var]
if var.name is not None: if var.name is not None:
varstr = var.name varstr = var.name
elif isinstance(var,gof.Constant): elif isinstance(var,gof.Constant):
varstr = str(var.data) varstr = str(var.data)
elif var in input_update and input_update[var].variable.name is not None:
varstr = input_update[var].variable.name
else: else:
#a var id is needed as otherwise var with the same type will be merged in the graph. #a var id is needed as otherwise var with the same type will be merged in the graph.
i = var_id.get(var,None) varstr = str(var.type)
if i is None: varstr += ' ' + str(len(var_str))
var_id[var]=len(var_id) var_str[var]=varstr
i = var_id[var]
varstr = str(var.type)+' '+str(i)
return varstr return varstr
# Update the inputs that have an update function
input_update={}
outputs = list(fct.maker.env.outputs)
for i in reversed(fct.maker.expanded_inputs):
if i.update is not None:
input_update[outputs.pop()] = i
for node_idx,node in enumerate(fct.maker.env.toposort()): for node_idx,node in enumerate(fct.maker.env.toposort()):
astr=str(node.op).replace(':','_')+' '+str(node_idx) astr=str(node.op).replace(':','_')+' '+str(node_idx)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论