提交 77523c8e authored 作者: Frederic's avatar Frederic

code cleanup/refactoring

上级 ada1a89a
......@@ -589,11 +589,13 @@ def pydotprint(fct, outfile=None,
if (not isinstance(mode, ProfileMode)
or not fct in mode.profile_stats):
mode = None
fct_fgraph = fct.maker.fgraph
outputs = fct.maker.fgraph.outputs
topo = fct.maker.fgraph.toposort()
elif isinstance(fct, gof.FunctionGraph):
mode = None
profile = None
fct_fgraph = fct
outputs = fct.outputs
topo = fct.toposort()
else:
raise ValueError(('pydotprint expects as input a theano.function or '
'the FunctionGraph of a function!'), fct)
......@@ -604,12 +606,13 @@ def pydotprint(fct, outfile=None,
return
g = pd.Dot()
if cond_highlight is not None:
c1 = pd.Cluster('Left')
c2 = pd.Cluster('Right')
c3 = pd.Cluster('Middle')
cond = None
for node in fct_fgraph.toposort():
for node in topo:
if (node.op.__class__.__name__ == 'IfElse'
and node.op.name == cond_highlight):
cond = node
......@@ -684,7 +687,6 @@ def pydotprint(fct, outfile=None,
all_strings.add(varstr)
return varstr
topo = fct_fgraph.toposort()
apply_name_cache = {}
def apply_name(node):
......@@ -736,7 +738,6 @@ def pydotprint(fct, outfile=None,
# Update the inputs that have an update function
input_update = {}
outputs = list(fct_fgraph.outputs)
if isinstance(fct, Function):
for i in reversed(fct.maker.expanded_inputs):
if i.update is not None:
......@@ -792,7 +793,7 @@ def pydotprint(fct, outfile=None,
for id, var in enumerate(node.outputs):
varstr = var_name(var)
out = any([x[0] == 'output' for x in var.clients])
out = var in outputs
label = str(var.type)
if len(node.outputs) > 1:
label = str(id) + ' ' + label
......@@ -830,10 +831,10 @@ def pydotprint(fct, outfile=None,
print 'The output file is available at', outfile
if assert_nb_all_strings != -1:
assert len(all_strings) == assert_nb_all_strings
assert len(all_strings) == assert_nb_all_strings, len(all_strings)
if scan_graphs:
scan_ops = [(idx, x) for idx, x in enumerate(fct_fgraph.toposort())
scan_ops = [(idx, x) for idx, x in enumerate(topo)
if isinstance(x.op, theano.scan_module.scan_op.Scan)]
path, fn = os.path.split(outfile)
basename = '.'.join(fn.split('.')[:-1])
......
......@@ -65,8 +65,6 @@ def test_pydotprint_variables():
new_handler.setLevel(logging.DEBUG)
orig_handler = theano.logging_default_handler
theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler)
theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler)
try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论