提交 2e699b53 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5237 from Sentient07/fix-small-error

Small deprecation fix
......@@ -8,7 +8,6 @@ from copy import copy
import logging
import os
import sys
import warnings
import hashlib
import numpy as np
......@@ -1080,175 +1079,6 @@ def pydotprint(fct, outfile=None,
print('The output file is available at', outfile)
def pydotprint_variables(vars,
outfile=None,
format='png',
depth=-1,
high_contrast=True, colorCodes=None,
max_label_size=50,
var_with_name_simple=False):
'''DEPRECATED: use pydotprint() instead.
Identical to pydotprint just that it starts from a variable
instead of a compiled function. Could be useful ?
'''
warnings.warn("pydotprint_variables() is deprecated."
" Use pydotprint() instead.")
if colorCodes is None:
colorCodes = default_colorCodes
if outfile is None:
outfile = os.path.join(config.compiledir, 'theano.pydotprint.' +
config.device + '.' + format)
if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot"
" and graphviz for `pydotprint_variables` to work.",
pydot_imported_msg)
if pd.__name__ == "pydot_ng":
raise RuntimeError("pydotprint_variables do not support pydot_ng."
"pydotprint_variables is also deprecated, "
"use pydotprint() that support pydot_ng")
g = pd.Dot()
my_list = {}
orphanes = []
if type(vars) not in (list, tuple):
vars = [vars]
var_str = {}
def var_name(var):
if var in var_str:
return var_str[var]
if var.name is not None:
if var_with_name_simple:
varstr = var.name
else:
varstr = 'name=' + var.name + " " + str(var.type)
elif isinstance(var, gof.Constant):
dstr = 'val=' + str(var.data)
if '\n' in dstr:
dstr = dstr[:dstr.index('\n')]
varstr = '%s %s' % (dstr, str(var.type))
else:
# a var id is needed as otherwise var with the same type will be
# merged in the graph.
varstr = str(var.type)
varstr += ' ' + str(len(var_str))
if len(varstr) > max_label_size:
varstr = varstr[:max_label_size - 3] + '...'
var_str[var] = varstr
return varstr
def apply_name(node):
name = str(node.op).replace(':', '_')
if len(name) > max_label_size:
name = name[:max_label_size - 3] + '...'
return name
def plot_apply(app, d):
if d == 0:
return
if app in my_list:
return
astr = apply_name(app) + '_' + str(len(my_list.keys()))
if len(astr) > max_label_size:
astr = astr[:max_label_size - 3] + '...'
my_list[app] = astr
use_color = None
for opName, color in iteritems(colorCodes):
if opName in app.op.__class__.__name__:
use_color = color
if use_color is None:
g.add_node(pd.Node(astr, shape='box'))
elif high_contrast:
g.add_node(pd.Node(astr, style='filled', fillcolor=use_color,
shape='box'))
else:
g.add_node(pd.Nonde(astr, color=use_color, shape='box'))
for i, nd in enumerate(app.inputs):
if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys()))
if len(varastr) > max_label_size:
varastr = varastr[:max_label_size - 3] + '...'
my_list[nd] = varastr
if nd.owner is not None:
g.add_node(pd.Node(varastr))
elif high_contrast:
g.add_node(pd.Node(varastr, style='filled',
fillcolor='green'))
else:
g.add_node(pd.Node(varastr, color='green'))
else:
varastr = my_list[nd]
label = None
if len(app.inputs) > 1:
label = str(i)
g.add_edge(pd.Edge(varastr, astr, label=label))
for i, nd in enumerate(app.outputs):
if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys()))
if len(varastr) > max_label_size:
varastr = varastr[:max_label_size - 3] + '...'
my_list[nd] = varastr
color = None
if nd in vars:
color = colorCodes['Output']
elif nd in orphanes:
color = 'gray'
if color is None:
g.add_node(pd.Node(varastr))
elif high_contrast:
g.add_node(pd.Node(varastr, style='filled',
fillcolor=color))
else:
g.add_node(pd.Node(varastr, color=color))
else:
varastr = my_list[nd]
label = None
if len(app.outputs) > 1:
label = str(i)
g.add_edge(pd.Edge(astr, varastr, label=label))
for nd in app.inputs:
if nd.owner:
plot_apply(nd.owner, d - 1)
for nd in vars:
if nd.owner:
for k in nd.owner.outputs:
if k not in vars:
orphanes.append(k)
for nd in vars:
if nd.owner:
plot_apply(nd.owner, depth)
try:
g.write(outfile, prog='dot', format=format)
except pd.InvocationException as e:
# Some version of pydot are bugged/don't work correctly with
# empty label. Provide a better user error message.
version = getattr(pd, '__version__', "")
if version == "1.0.28" and "label=]" in e.message:
raise Exception("pydot 1.0.28 is know to be bugged. Use another "
"working version of pydot")
elif "label=]" in e.message:
raise Exception("Your version of pydot " + version +
" returned an error. Version 1.0.28 is known"
" to be bugged and 1.0.25 to be working with"
" Theano. Using another version of pydot could"
" fix this problem. The pydot error is: " +
e.message)
raise
print('The output file is available at', outfile)
class _TagGenerator:
""" Class for giving abbreviated tags like to objects.
Only really intended for internal use in order to
......
......@@ -58,37 +58,6 @@ def test_pydotprint_return_image():
assert isinstance(ret, (str, bytes))
def test_pydotprint_variables():
"""
This is a REALLY PARTIAL TEST.
I did them to help debug stuff.
It make sure the code run.
"""
# Skip test if pydot is not available.
if not theano.printing.pydot_imported:
raise SkipTest('pydot not available')
x = tensor.dvector()
s = StringIO()
new_handler = logging.StreamHandler(s)
new_handler.setLevel(logging.DEBUG)
orig_handler = theano.logging_default_handler
theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler)
try:
theano.printing.pydotprint(x * 2)
if not theano.printing.pd.__name__ == "pydot_ng":
theano.printing.pydotprint_variables(x * 2)
finally:
theano.theano_logger.addHandler(orig_handler)
theano.theano_logger.removeHandler(new_handler)
def test_pydotprint_long_name():
"""This is a REALLY PARTIAL TEST.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论