提交 77523c8e authored 作者: Frederic's avatar Frederic

code cleanup/refactoring

上级 ada1a89a
...@@ -589,11 +589,13 @@ def pydotprint(fct, outfile=None, ...@@ -589,11 +589,13 @@ def pydotprint(fct, outfile=None,
if (not isinstance(mode, ProfileMode) if (not isinstance(mode, ProfileMode)
or not fct in mode.profile_stats): or not fct in mode.profile_stats):
mode = None mode = None
fct_fgraph = fct.maker.fgraph outputs = fct.maker.fgraph.outputs
topo = fct.maker.fgraph.toposort()
elif isinstance(fct, gof.FunctionGraph): elif isinstance(fct, gof.FunctionGraph):
mode = None mode = None
profile = None profile = None
fct_fgraph = fct outputs = fct.outputs
topo = fct.toposort()
else: else:
raise ValueError(('pydotprint expects as input a theano.function or ' raise ValueError(('pydotprint expects as input a theano.function or '
'the FunctionGraph of a function!'), fct) 'the FunctionGraph of a function!'), fct)
...@@ -604,12 +606,13 @@ def pydotprint(fct, outfile=None, ...@@ -604,12 +606,13 @@ def pydotprint(fct, outfile=None,
return return
g = pd.Dot() g = pd.Dot()
if cond_highlight is not None: if cond_highlight is not None:
c1 = pd.Cluster('Left') c1 = pd.Cluster('Left')
c2 = pd.Cluster('Right') c2 = pd.Cluster('Right')
c3 = pd.Cluster('Middle') c3 = pd.Cluster('Middle')
cond = None cond = None
for node in fct_fgraph.toposort(): for node in topo:
if (node.op.__class__.__name__ == 'IfElse' if (node.op.__class__.__name__ == 'IfElse'
and node.op.name == cond_highlight): and node.op.name == cond_highlight):
cond = node cond = node
...@@ -684,7 +687,6 @@ def pydotprint(fct, outfile=None, ...@@ -684,7 +687,6 @@ def pydotprint(fct, outfile=None,
all_strings.add(varstr) all_strings.add(varstr)
return varstr return varstr
topo = fct_fgraph.toposort()
apply_name_cache = {} apply_name_cache = {}
def apply_name(node): def apply_name(node):
...@@ -736,7 +738,6 @@ def pydotprint(fct, outfile=None, ...@@ -736,7 +738,6 @@ def pydotprint(fct, outfile=None,
# Update the inputs that have an update function # Update the inputs that have an update function
input_update = {} input_update = {}
outputs = list(fct_fgraph.outputs)
if isinstance(fct, Function): if isinstance(fct, Function):
for i in reversed(fct.maker.expanded_inputs): for i in reversed(fct.maker.expanded_inputs):
if i.update is not None: if i.update is not None:
...@@ -792,7 +793,7 @@ def pydotprint(fct, outfile=None, ...@@ -792,7 +793,7 @@ def pydotprint(fct, outfile=None,
for id, var in enumerate(node.outputs): for id, var in enumerate(node.outputs):
varstr = var_name(var) varstr = var_name(var)
out = any([x[0] == 'output' for x in var.clients]) out = var in outputs
label = str(var.type) label = str(var.type)
if len(node.outputs) > 1: if len(node.outputs) > 1:
label = str(id) + ' ' + label label = str(id) + ' ' + label
...@@ -830,10 +831,10 @@ def pydotprint(fct, outfile=None, ...@@ -830,10 +831,10 @@ def pydotprint(fct, outfile=None,
print 'The output file is available at', outfile print 'The output file is available at', outfile
if assert_nb_all_strings != -1: if assert_nb_all_strings != -1:
assert len(all_strings) == assert_nb_all_strings assert len(all_strings) == assert_nb_all_strings, len(all_strings)
if scan_graphs: if scan_graphs:
scan_ops = [(idx, x) for idx, x in enumerate(fct_fgraph.toposort()) scan_ops = [(idx, x) for idx, x in enumerate(topo)
if isinstance(x.op, theano.scan_module.scan_op.Scan)] if isinstance(x.op, theano.scan_module.scan_op.Scan)]
path, fn = os.path.split(outfile) path, fn = os.path.split(outfile)
basename = '.'.join(fn.split('.')[:-1]) basename = '.'.join(fn.split('.')[:-1])
......
...@@ -65,8 +65,6 @@ def test_pydotprint_variables(): ...@@ -65,8 +65,6 @@ def test_pydotprint_variables():
new_handler.setLevel(logging.DEBUG) new_handler.setLevel(logging.DEBUG)
orig_handler = theano.logging_default_handler orig_handler = theano.logging_default_handler
theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler)
theano.theano_logger.removeHandler(orig_handler) theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler) theano.theano_logger.addHandler(new_handler)
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论