提交 2e699b53 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5237 from Sentient07/fix-small-error

Small deprecation fix
...@@ -8,7 +8,6 @@ from copy import copy ...@@ -8,7 +8,6 @@ from copy import copy
import logging import logging
import os import os
import sys import sys
import warnings
import hashlib import hashlib
import numpy as np import numpy as np
...@@ -1080,175 +1079,6 @@ def pydotprint(fct, outfile=None, ...@@ -1080,175 +1079,6 @@ def pydotprint(fct, outfile=None,
print('The output file is available at', outfile) print('The output file is available at', outfile)
def pydotprint_variables(vars,
outfile=None,
format='png',
depth=-1,
high_contrast=True, colorCodes=None,
max_label_size=50,
var_with_name_simple=False):
'''DEPRECATED: use pydotprint() instead.
Identical to pydotprint just that it starts from a variable
instead of a compiled function. Could be useful ?
'''
warnings.warn("pydotprint_variables() is deprecated."
" Use pydotprint() instead.")
if colorCodes is None:
colorCodes = default_colorCodes
if outfile is None:
outfile = os.path.join(config.compiledir, 'theano.pydotprint.' +
config.device + '.' + format)
if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot"
" and graphviz for `pydotprint_variables` to work.",
pydot_imported_msg)
if pd.__name__ == "pydot_ng":
raise RuntimeError("pydotprint_variables do not support pydot_ng."
"pydotprint_variables is also deprecated, "
"use pydotprint() that support pydot_ng")
g = pd.Dot()
my_list = {}
orphanes = []
if type(vars) not in (list, tuple):
vars = [vars]
var_str = {}
def var_name(var):
if var in var_str:
return var_str[var]
if var.name is not None:
if var_with_name_simple:
varstr = var.name
else:
varstr = 'name=' + var.name + " " + str(var.type)
elif isinstance(var, gof.Constant):
dstr = 'val=' + str(var.data)
if '\n' in dstr:
dstr = dstr[:dstr.index('\n')]
varstr = '%s %s' % (dstr, str(var.type))
else:
# a var id is needed as otherwise var with the same type will be
# merged in the graph.
varstr = str(var.type)
varstr += ' ' + str(len(var_str))
if len(varstr) > max_label_size:
varstr = varstr[:max_label_size - 3] + '...'
var_str[var] = varstr
return varstr
def apply_name(node):
name = str(node.op).replace(':', '_')
if len(name) > max_label_size:
name = name[:max_label_size - 3] + '...'
return name
def plot_apply(app, d):
if d == 0:
return
if app in my_list:
return
astr = apply_name(app) + '_' + str(len(my_list.keys()))
if len(astr) > max_label_size:
astr = astr[:max_label_size - 3] + '...'
my_list[app] = astr
use_color = None
for opName, color in iteritems(colorCodes):
if opName in app.op.__class__.__name__:
use_color = color
if use_color is None:
g.add_node(pd.Node(astr, shape='box'))
elif high_contrast:
g.add_node(pd.Node(astr, style='filled', fillcolor=use_color,
shape='box'))
else:
g.add_node(pd.Nonde(astr, color=use_color, shape='box'))
for i, nd in enumerate(app.inputs):
if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys()))
if len(varastr) > max_label_size:
varastr = varastr[:max_label_size - 3] + '...'
my_list[nd] = varastr
if nd.owner is not None:
g.add_node(pd.Node(varastr))
elif high_contrast:
g.add_node(pd.Node(varastr, style='filled',
fillcolor='green'))
else:
g.add_node(pd.Node(varastr, color='green'))
else:
varastr = my_list[nd]
label = None
if len(app.inputs) > 1:
label = str(i)
g.add_edge(pd.Edge(varastr, astr, label=label))
for i, nd in enumerate(app.outputs):
if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys()))
if len(varastr) > max_label_size:
varastr = varastr[:max_label_size - 3] + '...'
my_list[nd] = varastr
color = None
if nd in vars:
color = colorCodes['Output']
elif nd in orphanes:
color = 'gray'
if color is None:
g.add_node(pd.Node(varastr))
elif high_contrast:
g.add_node(pd.Node(varastr, style='filled',
fillcolor=color))
else:
g.add_node(pd.Node(varastr, color=color))
else:
varastr = my_list[nd]
label = None
if len(app.outputs) > 1:
label = str(i)
g.add_edge(pd.Edge(astr, varastr, label=label))
for nd in app.inputs:
if nd.owner:
plot_apply(nd.owner, d - 1)
for nd in vars:
if nd.owner:
for k in nd.owner.outputs:
if k not in vars:
orphanes.append(k)
for nd in vars:
if nd.owner:
plot_apply(nd.owner, depth)
try:
g.write(outfile, prog='dot', format=format)
except pd.InvocationException as e:
# Some version of pydot are bugged/don't work correctly with
# empty label. Provide a better user error message.
version = getattr(pd, '__version__', "")
if version == "1.0.28" and "label=]" in e.message:
raise Exception("pydot 1.0.28 is know to be bugged. Use another "
"working version of pydot")
elif "label=]" in e.message:
raise Exception("Your version of pydot " + version +
" returned an error. Version 1.0.28 is known"
" to be bugged and 1.0.25 to be working with"
" Theano. Using another version of pydot could"
" fix this problem. The pydot error is: " +
e.message)
raise
print('The output file is available at', outfile)
class _TagGenerator: class _TagGenerator:
""" Class for giving abbreviated tags like to objects. """ Class for giving abbreviated tags like to objects.
Only really intended for internal use in order to Only really intended for internal use in order to
......
...@@ -58,37 +58,6 @@ def test_pydotprint_return_image(): ...@@ -58,37 +58,6 @@ def test_pydotprint_return_image():
assert isinstance(ret, (str, bytes)) assert isinstance(ret, (str, bytes))
def test_pydotprint_variables():
"""
This is a REALLY PARTIAL TEST.
I did them to help debug stuff.
It make sure the code run.
"""
# Skip test if pydot is not available.
if not theano.printing.pydot_imported:
raise SkipTest('pydot not available')
x = tensor.dvector()
s = StringIO()
new_handler = logging.StreamHandler(s)
new_handler.setLevel(logging.DEBUG)
orig_handler = theano.logging_default_handler
theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler)
try:
theano.printing.pydotprint(x * 2)
if not theano.printing.pd.__name__ == "pydot_ng":
theano.printing.pydotprint_variables(x * 2)
finally:
theano.theano_logger.addHandler(orig_handler)
theano.theano_logger.removeHandler(new_handler)
def test_pydotprint_long_name(): def test_pydotprint_long_name():
"""This is a REALLY PARTIAL TEST. """This is a REALLY PARTIAL TEST.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论