提交 b73758ee authored 作者: Frederic's avatar Frederic

Fix pydotprint update of shared var. Make the a new color cyan.

上级 b7f4e2c3
...@@ -571,7 +571,7 @@ Print to the terminal a math-like expression. ...@@ -571,7 +571,7 @@ Print to the terminal a math-like expression.
default_colorCodes = {'GpuFromHost': 'red', default_colorCodes = {'GpuFromHost': 'red',
'HostFromGpu': 'red', 'HostFromGpu': 'red',
'Scan': 'yellow', 'Scan': 'yellow',
'Shape': 'cyan', 'Shape': 'brown',
'IfElse': 'magenta', 'IfElse': 'magenta',
'Elemwise': '#FFAABB', # dark pink 'Elemwise': '#FFAABB', # dark pink
'Subtensor': '#FFAAFF', # purple 'Subtensor': '#FFAAFF', # purple
...@@ -642,11 +642,20 @@ def pydotprint(fct, outfile=None, ...@@ -642,11 +642,20 @@ def pydotprint(fct, outfile=None,
label each edge between an input and the Apply node with the label each edge between an input and the Apply node with the
input's index. input's index.
Green boxes are inputs variables to the graph, Variable color code::
blue boxes are outputs variables of the graph, - Cyan boxes are SharedVariable, inputs and/or outputs) of the graph,
grey boxes are variables that are not outputs and are not used, - Green boxes are inputs variables to the graph,
red ellipses are transfers from/to the gpu (ops with names GpuFromHost, - Blue boxes are outputs variables of the graph,
HostFromGpu). - Grey boxes are variables that are not outputs and are not used,
Default apply node code::
- Red ellipses are transfers from/to the gpu
- Yellow are scan node
- Brown are shape node
- Magenta are IfElse node
- Dark pink are elemwise node
- Purple are subtensor
- Orange are alloc node
For edges, they are black by default. If a node returns a view For edges, they are black by default. If a node returns a view
of an input, we put the corresponding input edge in blue. If it of an input, we put the corresponding input edge in blue. If it
...@@ -749,11 +758,11 @@ def pydotprint(fct, outfile=None, ...@@ -749,11 +758,11 @@ def pydotprint(fct, outfile=None,
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
varstr = '%s %s' % (dstr, str(var.type)) varstr = '%s %s' % (dstr, str(var.type))
elif (var in input_update and elif (var in input_update and
input_update[var].variable.name is not None): input_update[var].name is not None):
if var_with_name_simple: if var_with_name_simple:
varstr = input_update[var].variable.name + " UPDATE" varstr = input_update[var].variable.name
else: else:
varstr = (input_update[var].variable.name + " UPDATE " + varstr = (input_update[var].variable.name +
str(var.type)) str(var.type))
else: else:
# a var id is needed as otherwise var with the same type will be # a var id is needed as otherwise var with the same type will be
...@@ -775,6 +784,11 @@ def pydotprint(fct, outfile=None, ...@@ -775,6 +784,11 @@ def pydotprint(fct, outfile=None,
'...' + '...' +
suffix) suffix)
var_str[var] = varstr var_str[var] = varstr
# The var that represent the new value, must be linked to the
# input var.
if var in input_update.values():
var_str[reverse_input_update[var]] = varstr
all_strings.add(varstr) all_strings.add(varstr)
return varstr return varstr
...@@ -829,13 +843,18 @@ def pydotprint(fct, outfile=None, ...@@ -829,13 +843,18 @@ def pydotprint(fct, outfile=None,
# Update the inputs that have an update function # Update the inputs that have an update function
input_update = {} input_update = {}
reverse_input_update = {}
# Here outputs can be the original list, as we should not change # Here outputs can be the original list, as we should not change
# it, we must copy it. # it, we must copy it.
outputs = list(outputs) outputs = list(outputs)
if isinstance(fct, Function): if isinstance(fct, Function):
for i in reversed(fct.maker.expanded_inputs): for i, fg_ii in reversed(zip(fct.maker.expanded_inputs,
fct.maker.fgraph.inputs)):
if i.update is not None: if i.update is not None:
input_update[outputs.pop()] = i k = outputs.pop()
# Use the fgaph.inputs as it isn't the same as maker.inputs
input_update[k] = fg_ii
reverse_input_update[fg_ii] = k
apply_shape = 'ellipse' apply_shape = 'ellipse'
var_shape = 'box' var_shape = 'box'
...@@ -878,13 +897,18 @@ def pydotprint(fct, outfile=None, ...@@ -878,13 +897,18 @@ def pydotprint(fct, outfile=None,
list.__add__, node.op.destroy_map.values(), []): list.__add__, node.op.destroy_map.values(), []):
param['color'] = 'red' param['color'] = 'red'
if var.owner is None: if var.owner is None:
color = 'green'
if isinstance(var, theano.compile.SharedVariable):
# Input are green, output blue
# Mixing blue and green give cyan! (input and output var)
color = "cyan"
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, g.add_node(pd.Node(varstr,
style='filled', style='filled',
fillcolor='green', fillcolor=color,
shape=var_shape)) shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color='green', shape=var_shape)) g.add_node(pd.Node(varstr, color=color, shape=var_shape))
g.add_edge(pd.Edge(varstr, astr, label=label, **param)) g.add_edge(pd.Edge(varstr, astr, label=label, **param))
elif var.name or not compact or var in outputs: elif var.name or not compact or var in outputs:
g.add_edge(pd.Edge(varstr, astr, label=label, **param)) g.add_edge(pd.Edge(varstr, astr, label=label, **param))
...@@ -910,12 +934,13 @@ def pydotprint(fct, outfile=None, ...@@ -910,12 +934,13 @@ def pydotprint(fct, outfile=None,
g.add_node(pd.Node(varstr, color='blue', shape=var_shape)) g.add_node(pd.Node(varstr, color='blue', shape=var_shape))
elif len(var.clients) == 0: elif len(var.clients) == 0:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
# grey mean that output var isn't used
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, style='filled', g.add_node(pd.Node(varstr, style='filled',
fillcolor='grey', shape=var_shape)) fillcolor='grey', shape=var_shape))
else: else:
g.add_node(pd.Node(varstr, color='grey', shape=var_shape)) g.add_node(pd.Node(varstr, color='grey', shape=var_shape))
elif var.name or not compact: elif var.name or not compact or var in input_update:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
# else: # else:
# don't add egde here as it is already added from the inputs. # don't add egde here as it is already added from the inputs.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论