提交 b73758ee authored 作者: Frederic's avatar Frederic

Fix pydotprint update of shared var. Make the a new color cyan.

上级 b7f4e2c3
......@@ -571,7 +571,7 @@ Print to the terminal a math-like expression.
default_colorCodes = {'GpuFromHost': 'red',
'HostFromGpu': 'red',
'Scan': 'yellow',
'Shape': 'cyan',
'Shape': 'brown',
'IfElse': 'magenta',
'Elemwise': '#FFAABB', # dark pink
'Subtensor': '#FFAAFF', # purple
......@@ -642,11 +642,20 @@ def pydotprint(fct, outfile=None,
label each edge between an input and the Apply node with the
input's index.
Green boxes are inputs variables to the graph,
blue boxes are outputs variables of the graph,
grey boxes are variables that are not outputs and are not used,
red ellipses are transfers from/to the gpu (ops with names GpuFromHost,
HostFromGpu).
Variable color code::
- Cyan boxes are SharedVariable, inputs and/or outputs) of the graph,
- Green boxes are inputs variables to the graph,
- Blue boxes are outputs variables of the graph,
- Grey boxes are variables that are not outputs and are not used,
Default apply node code::
- Red ellipses are transfers from/to the gpu
- Yellow are scan node
- Brown are shape node
- Magenta are IfElse node
- Dark pink are elemwise node
- Purple are subtensor
- Orange are alloc node
For edges, they are black by default. If a node returns a view
of an input, we put the corresponding input edge in blue. If it
......@@ -749,11 +758,11 @@ def pydotprint(fct, outfile=None,
dstr = dstr[:dstr.index('\n')]
varstr = '%s %s' % (dstr, str(var.type))
elif (var in input_update and
input_update[var].variable.name is not None):
input_update[var].name is not None):
if var_with_name_simple:
varstr = input_update[var].variable.name + " UPDATE"
varstr = input_update[var].variable.name
else:
varstr = (input_update[var].variable.name + " UPDATE " +
varstr = (input_update[var].variable.name +
str(var.type))
else:
# a var id is needed as otherwise var with the same type will be
......@@ -775,6 +784,11 @@ def pydotprint(fct, outfile=None,
'...' +
suffix)
var_str[var] = varstr
# The var that represent the new value, must be linked to the
# input var.
if var in input_update.values():
var_str[reverse_input_update[var]] = varstr
all_strings.add(varstr)
return varstr
......@@ -829,13 +843,18 @@ def pydotprint(fct, outfile=None,
# Update the inputs that have an update function
input_update = {}
reverse_input_update = {}
# Here outputs can be the original list, as we should not change
# it, we must copy it.
outputs = list(outputs)
if isinstance(fct, Function):
for i in reversed(fct.maker.expanded_inputs):
for i, fg_ii in reversed(zip(fct.maker.expanded_inputs,
fct.maker.fgraph.inputs)):
if i.update is not None:
input_update[outputs.pop()] = i
k = outputs.pop()
# Use the fgaph.inputs as it isn't the same as maker.inputs
input_update[k] = fg_ii
reverse_input_update[fg_ii] = k
apply_shape = 'ellipse'
var_shape = 'box'
......@@ -878,13 +897,18 @@ def pydotprint(fct, outfile=None,
list.__add__, node.op.destroy_map.values(), []):
param['color'] = 'red'
if var.owner is None:
color = 'green'
if isinstance(var, theano.compile.SharedVariable):
# Input are green, output blue
# Mixing blue and green give cyan! (input and output var)
color = "cyan"
if high_contrast:
g.add_node(pd.Node(varstr,
style='filled',
fillcolor='green',
fillcolor=color,
shape=var_shape))
else:
g.add_node(pd.Node(varstr, color='green', shape=var_shape))
g.add_node(pd.Node(varstr, color=color, shape=var_shape))
g.add_edge(pd.Edge(varstr, astr, label=label, **param))
elif var.name or not compact or var in outputs:
g.add_edge(pd.Edge(varstr, astr, label=label, **param))
......@@ -910,12 +934,13 @@ def pydotprint(fct, outfile=None,
g.add_node(pd.Node(varstr, color='blue', shape=var_shape))
elif len(var.clients) == 0:
g.add_edge(pd.Edge(astr, varstr, label=label))
# grey mean that output var isn't used
if high_contrast:
g.add_node(pd.Node(varstr, style='filled',
fillcolor='grey', shape=var_shape))
else:
g.add_node(pd.Node(varstr, color='grey', shape=var_shape))
elif var.name or not compact:
elif var.name or not compact or var in input_update:
g.add_edge(pd.Edge(astr, varstr, label=label))
# else:
# don't add egde here as it is already added from the inputs.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论