提交 5dd4c64a authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #164 from nouiz/pep8

Pep8 I've looked over it and everything seems fine.
"""Pretty-printing (pprint()), the 'Print' Op, debugprint() and pydotprint(). """Pretty-printing (pprint()), the 'Print' Op, debugprint() and pydotprint().
They all allow different way to print a graph or the result of an Op in a graph(Print Op)
They all allow different way to print a graph or the result of an Op
in a graph(Print Op)
""" """
from copy import copy from copy import copy
import logging import logging
import sys, os, StringIO import os
import StringIO
import sys
import numpy import numpy
try:
import pydot as pd
pydot_imported = True
except ImportError:
pydot_imported = False
import theano import theano
import gof import gof
from theano import config from theano import config
...@@ -15,7 +25,8 @@ from theano.gof.python25 import any ...@@ -15,7 +25,8 @@ from theano.gof.python25 import any
from theano.compile import Function, debugmode from theano.compile import Function, debugmode
from theano.compile.profilemode import ProfileMode from theano.compile.profilemode import ProfileMode
_logger=logging.getLogger("theano.printing") _logger = logging.getLogger("theano.printing")
def debugprint(obj, depth=-1, print_type=False, file=None): def debugprint(obj, depth=-1, print_type=False, file=None):
"""Print a computation graph to file """Print a computation graph to file
...@@ -32,17 +43,19 @@ def debugprint(obj, depth=-1, print_type=False, file=None): ...@@ -32,17 +43,19 @@ def debugprint(obj, depth=-1, print_type=False, file=None):
:returns: string if `file` == 'str', else file arg :returns: string if `file` == 'str', else file arg
Each line printed represents a Variable in the graph. Each line printed represents a Variable in the graph.
The indentation of each line corresponds to its depth in the symbolic graph. The indentation of lines corresponds to its depth in the symbolic graph.
The first part of the text identifies whether it is an input (if a name or type is printed) The first part of the text identifies whether it is an input
or the output of some Apply (in which case the Op is printed). (if a name or type is printed) or the output of some Apply (in which case
the Op is printed).
The second part of the text is the memory location of the Variable. The second part of the text is the memory location of the Variable.
If print_type is True, there is a third part, containing the type of the Variable If print_type is True, we add a part containing the type of the Variable
If a Variable is encountered multiple times in the depth-first search, it is only printed If a Variable is encountered multiple times in the depth-first search,
recursively the first time. Later, just the Variable and its memory location are printed. it is only printed recursively the first time. Later, just the Variable
and its memory location are printed.
If an Apply has multiple outputs, then a '.N' suffix will be appended to the Apply's If an Apply has multiple outputs, then a '.N' suffix will be appended
identifier, to indicate which output a line corresponds to. to the Apply's identifier, to indicate which output a line corresponds to.
""" """
if file == 'str': if file == 'str':
...@@ -73,7 +86,7 @@ def debugprint(obj, depth=-1, print_type=False, file=None): ...@@ -73,7 +86,7 @@ def debugprint(obj, depth=-1, print_type=False, file=None):
file=_file, order=order) file=_file, order=order)
if file is _file: if file is _file:
return file return file
elif file=='str': elif file == 'str':
return _file.getvalue() return _file.getvalue()
else: else:
_file.flush() _file.flush()
...@@ -86,42 +99,50 @@ def _print_fn(op, xin): ...@@ -86,42 +99,50 @@ def _print_fn(op, xin):
pmsg = temp() pmsg = temp()
else: else:
pmsg = temp pmsg = temp
print op.message, attr,'=', pmsg print op.message, attr, '=', pmsg
class Print(Op): class Print(Op):
"""This identity-like Op has the side effect of printing a message followed by its inputs """ This identity-like Op print as a side effect.
when it runs. Default behaviour is to print the __str__ representation. Optionally, one
can pass a list of the input member functions to execute, or attributes to print. This identity-like Op has the side effect of printing a message
followed by its inputs when it runs. Default behaviour is to print
the __str__ representation. Optionally, one can pass a list of the
input member functions to execute, or attributes to print.
@type message: String @type message: String
@param message: string to prepend to the output @param message: string to prepend to the output
@type attrs: list of Strings @type attrs: list of Strings
@param attrs: list of input node attributes or member functions to print. Functions are @param attrs: list of input node attributes or member functions to print.
identified through callable(), executed and their return value printed. Functions are identified through callable(), executed and
their return value printed.
:note: WARNING. This can disable some optimization(speed and stabilization)! :note: WARNING. This can disable some optimizations!
(speed and/pr stabilization)
""" """
view_map={0:[0]} view_map = {0: [0]}
def __init__(self,message="", attrs=("__str__",), global_fn=_print_fn):
self.message=message
self.attrs=tuple(attrs) # attrs should be a hashable iterable
self.global_fn=global_fn
def make_node(self,xin): def __init__(self, message="", attrs=("__str__",), global_fn=_print_fn):
self.message = message
self.attrs = tuple(attrs) # attrs should be a hashable iterable
self.global_fn = global_fn
def make_node(self, xin):
xout = xin.type.make_variable() xout = xin.type.make_variable()
return Apply(op = self, inputs = [xin], outputs=[xout]) return Apply(op=self, inputs=[xin], outputs=[xout])
def perform(self,node,inputs,output_storage): def perform(self, node, inputs, output_storage):
xin, = inputs xin, = inputs
xout, = output_storage xout, = output_storage
xout[0] = xin xout[0] = xin
self.global_fn(self, xin) self.global_fn(self, xin)
def grad(self,input,output_gradients): def grad(self, input, output_gradients):
return output_gradients return output_gradients
def __eq__(self, other): def __eq__(self, other):
return type(self)==type(other) and self.message==other.message and self.attrs==other.attrs return (type(self) == type(other) and self.message == other.message
and self.attrs == other.attrs)
def __hash__(self): def __hash__(self):
return hash(self.message) ^ hash(self.attrs) return hash(self.message) ^ hash(self.attrs)
...@@ -133,22 +154,23 @@ class Print(Op): ...@@ -133,22 +154,23 @@ class Print(Op):
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
class PrinterState(gof.utils.scratchpad): class PrinterState(gof.utils.scratchpad):
def __init__(self, props = {}, **more_props): def __init__(self, props={}, **more_props):
if isinstance(props, gof.utils.scratchpad): if isinstance(props, gof.utils.scratchpad):
self.__update__(props) self.__update__(props)
else: else:
self.__dict__.update(props) self.__dict__.update(props)
self.__dict__.update(more_props) self.__dict__.update(more_props)
def clone(self, props = {}, **more_props): def clone(self, props={}, **more_props):
return PrinterState(self, **dict(props, **more_props)) return PrinterState(self, **dict(props, **more_props))
class OperatorPrinter: class OperatorPrinter:
def __init__(self, operator, precedence, assoc = 'left'): def __init__(self, operator, precedence, assoc='left'):
self.operator = operator self.operator = operator
self.precedence = precedence self.precedence = precedence
self.assoc = assoc self.assoc = assoc
...@@ -157,7 +179,8 @@ class OperatorPrinter: ...@@ -157,7 +179,8 @@ class OperatorPrinter:
pprinter = pstate.pprinter pprinter = pstate.pprinter
node = output.owner node = output.owner
if node is None: if node is None:
raise TypeError("operator %s cannot represent a variable that is not the result of an operation" % self.operator) raise TypeError("operator %s cannot represent a variable that is "
"not the result of an operation" % self.operator)
## Precedence seems to be buggy, see #249 ## Precedence seems to be buggy, see #249
## So, in doubt, we parenthesize everything. ## So, in doubt, we parenthesize everything.
...@@ -172,17 +195,22 @@ class OperatorPrinter: ...@@ -172,17 +195,22 @@ class OperatorPrinter:
input_strings = [] input_strings = []
max_i = len(node.inputs) - 1 max_i = len(node.inputs) - 1
for i, input in enumerate(node.inputs): for i, input in enumerate(node.inputs):
if self.assoc == 'left' and i != 0 or self.assoc == 'right' and i != max_i: if (self.assoc == 'left' and i != 0 or self.assoc == 'right'
s = pprinter.process(input, pstate.clone(precedence = self.precedence + 1e-6)) and i != max_i):
s = pprinter.process(input, pstate.clone(
precedence=self.precedence + 1e-6))
else: else:
s = pprinter.process(input, pstate.clone(precedence = self.precedence)) s = pprinter.process(input, pstate.clone(
precedence=self.precedence))
input_strings.append(s) input_strings.append(s)
if len(input_strings) == 1: if len(input_strings) == 1:
s = self.operator + input_strings[0] s = self.operator + input_strings[0]
else: else:
s = (" %s " % self.operator).join(input_strings) s = (" %s " % self.operator).join(input_strings)
if parenthesize: return "(%s)" % s if parenthesize:
else: return s return "(%s)" % s
else:
return s
class PatternPrinter: class PatternPrinter:
...@@ -199,13 +227,19 @@ class PatternPrinter: ...@@ -199,13 +227,19 @@ class PatternPrinter:
pprinter = pstate.pprinter pprinter = pstate.pprinter
node = output.owner node = output.owner
if node is None: if node is None:
raise TypeError("Patterns %s cannot represent a variable that is not the result of an operation" % self.patterns) raise TypeError("Patterns %s cannot represent a variable that is "
"not the result of an operation" % self.patterns)
idx = node.outputs.index(output) idx = node.outputs.index(output)
pattern, precedences = self.patterns[idx] pattern, precedences = self.patterns[idx]
precedences += (1000,) * len(node.inputs) precedences += (1000,) * len(node.inputs)
return pattern % dict((str(i), x)
for i, x in enumerate(pprinter.process(input, pstate.clone(precedence = precedence)) pp_process = lambda input, precedence: pprinter.process(
for input, precedence in zip(node.inputs, precedences))) input, pstate.clone(precedence=precedence))
d = dict((str(i), x)
for i, x in enumerate(pp_process(input, precedence)
for input, precedence in
zip(node.inputs, precedences)))
return pattern % d
class FunctionPrinter: class FunctionPrinter:
...@@ -217,12 +251,15 @@ class FunctionPrinter: ...@@ -217,12 +251,15 @@ class FunctionPrinter:
pprinter = pstate.pprinter pprinter = pstate.pprinter
node = output.owner node = output.owner
if node is None: if node is None:
raise TypeError("function %s cannot represent a variable that is not the result of an operation" % self.names) raise TypeError("function %s cannot represent a variable that is "
"not the result of an operation" % self.names)
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
return "%s(%s)" % (name, ", ".join([pprinter.process(input, pstate.clone(precedence = -1000)) return "%s(%s)" % (name, ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs])) for input in node.inputs]))
class MemberPrinter: class MemberPrinter:
def __init__(self, *names): def __init__(self, *names):
...@@ -232,12 +269,15 @@ class MemberPrinter: ...@@ -232,12 +269,15 @@ class MemberPrinter:
pprinter = pstate.pprinter pprinter = pstate.pprinter
node = output.owner node = output.owner
if node is None: if node is None:
raise TypeError("function %s cannot represent a variable that is not the result of an operation" % self.function) raise TypeError("function %s cannot represent a variable that is"
" not the result of an operation" % self.function)
names = self.names names = self.names
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
input = node.inputs[0] input = node.inputs[0]
return "%s.%s" % (pprinter.process(input, pstate.clone(precedence = 1000)), name) return "%s.%s" % (pprinter.process(input,
pstate.clone(precedence=1000)),
name)
class IgnorePrinter: class IgnorePrinter:
...@@ -246,7 +286,8 @@ class IgnorePrinter: ...@@ -246,7 +286,8 @@ class IgnorePrinter:
pprinter = pstate.pprinter pprinter = pstate.pprinter
node = output.owner node = output.owner
if node is None: if node is None:
raise TypeError("function %s cannot represent a variable that is not the result of an operation" % self.function) raise TypeError("function %s cannot represent a variable that is"
" not the result of an operation" % self.function)
input = node.inputs[0] input = node.inputs[0]
return "%s" % pprinter.process(input, pstate) return "%s" % pprinter.process(input, pstate)
...@@ -261,9 +302,11 @@ class DefaultPrinter: ...@@ -261,9 +302,11 @@ class DefaultPrinter:
node = r.owner node = r.owner
if node is None: if node is None:
return LeafPrinter().process(r, pstate) return LeafPrinter().process(r, pstate)
return "%s(%s)" % (str(node.op), ", ".join([pprinter.process(input, pstate.clone(precedence = -1000)) return "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs])) for input in node.inputs]))
class LeafPrinter: class LeafPrinter:
def process(self, r, pstate): def process(self, r, pstate):
if r.name in greek: if r.name in greek:
...@@ -280,14 +323,15 @@ class PPrinter: ...@@ -280,14 +323,15 @@ class PPrinter:
def assign(self, condition, printer): def assign(self, condition, printer):
if isinstance(condition, gof.Op): if isinstance(condition, gof.Op):
op = condition op = condition
condition = lambda pstate, r: r.owner is not None and r.owner.op == op condition = (lambda pstate, r: r.owner is not None
and r.owner.op == op)
self.printers.insert(0, (condition, printer)) self.printers.insert(0, (condition, printer))
def process(self, r, pstate = None): def process(self, r, pstate=None):
if pstate is None: if pstate is None:
pstate = PrinterState(pprinter = self) pstate = PrinterState(pprinter=self)
elif isinstance(pstate, dict): elif isinstance(pstate, dict):
pstate = PrinterState(pprinter = self, **pstate) pstate = PrinterState(pprinter=self, **pstate)
for condition, printer in self.printers: for condition, printer in self.printers:
if condition(pstate, r): if condition(pstate, r):
return printer.process(r, pstate) return printer.process(r, pstate)
...@@ -302,15 +346,20 @@ class PPrinter: ...@@ -302,15 +346,20 @@ class PPrinter:
cp.assign(condition, printer) cp.assign(condition, printer)
return cp return cp
def process_graph(self, inputs, outputs, updates = {}, display_inputs = False): def process_graph(self, inputs, outputs, updates={},
if not isinstance(inputs, (list, tuple)): inputs = [inputs] display_inputs=False):
if not isinstance(outputs, (list, tuple)): outputs = [outputs] if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
current = None current = None
if display_inputs: if display_inputs:
strings = [(0, "inputs: " + ", ".join(map(str, list(inputs) + updates.keys())))] strings = [(0, "inputs: " + ", ".join(
map(str, list(inputs) + updates.keys())))]
else: else:
strings = [] strings = []
pprinter = self.clone_assign(lambda pstate, r: r.name is not None and r is not current, pprinter = self.clone_assign(lambda pstate, r: r.name is not None
and r is not current,
LeafPrinter()) LeafPrinter())
inv_updates = dict((b, a) for (a, b) in updates.iteritems()) inv_updates = dict((b, a) for (a, b) in updates.iteritems())
i = 1 i = 1
...@@ -319,7 +368,8 @@ class PPrinter: ...@@ -319,7 +368,8 @@ class PPrinter:
for output in node.outputs: for output in node.outputs:
if output in inv_updates: if output in inv_updates:
name = str(inv_updates[output]) name = str(inv_updates[output])
strings.append((i + 1000, "%s <- %s" % (name, pprinter.process(output)))) strings.append((i + 1000, "%s <- %s" % (
name, pprinter.process(output))))
i += 1 i += 1
if output.name is not None or output in outputs: if output.name is not None or output in outputs:
if output.name is None: if output.name is None:
...@@ -327,16 +377,19 @@ class PPrinter: ...@@ -327,16 +377,19 @@ class PPrinter:
else: else:
name = output.name name = output.name
#backport #backport
#name = 'out[%i]' % outputs.index(output) if output.name is None else output.name #name = 'out[%i]' % outputs.index(output) if output.name
# is None else output.name
current = output current = output
try: try:
idx = 2000 + outputs.index(output) idx = 2000 + outputs.index(output)
except ValueError: except ValueError:
idx = i idx = i
if len(outputs) == 1 and outputs[0] is output: if len(outputs) == 1 and outputs[0] is output:
strings.append((idx, "return %s" % pprinter.process(output))) strings.append((idx, "return %s" %
pprinter.process(output)))
else: else:
strings.append((idx, "%s = %s" % (name, pprinter.process(output)))) strings.append((idx, "%s = %s" %
(name, pprinter.process(output))))
i += 1 i += 1
strings.sort() strings.sort()
return "\n".join(s[1] for s in strings) return "\n".join(s[1] for s in strings)
...@@ -354,29 +407,30 @@ class PPrinter: ...@@ -354,29 +407,30 @@ class PPrinter:
use_ascii = True use_ascii = True
if use_ascii: if use_ascii:
special = dict(middle_dot = "\\dot", special = dict(middle_dot="\\dot",
big_sigma = "\\Sigma") big_sigma="\\Sigma")
greek = dict(alpha = "\\alpha", greek = dict(alpha="\\alpha",
beta = "\\beta", beta="\\beta",
gamma = "\\gamma", gamma="\\gamma",
delta = "\\delta", delta="\\delta",
epsilon = "\\epsilon") epsilon="\\epsilon")
else: else:
special = dict(middle_dot = u"\u00B7", special = dict(middle_dot=u"\u00B7",
big_sigma = u"\u03A3") big_sigma=u"\u03A3")
greek = dict(alpha = u"\u03B1", greek = dict(alpha=u"\u03B1",
beta = u"\u03B2", beta=u"\u03B2",
gamma = u"\u03B3", gamma=u"\u03B3",
delta = u"\u03B4", delta=u"\u03B4",
epsilon = u"\u03B5") epsilon=u"\u03B5")
pprint = PPrinter() pprint = PPrinter()
pprint.assign(lambda pstate, r: True, DefaultPrinter()) pprint.assign(lambda pstate, r: True, DefaultPrinter())
pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is not r and r.name is not None, pprint.assign(lambda pstate, r: hasattr(pstate, 'target')
and pstate.target is not r and r.name is not None,
LeafPrinter()) LeafPrinter())
pp = pprint pp = pprint
...@@ -384,14 +438,15 @@ pp = pprint ...@@ -384,14 +438,15 @@ pp = pprint
# colors not used: orange, amber#FFBF00, purple, pink, # colors not used: orange, amber#FFBF00, purple, pink,
# used by default: green, blue, grey, red # used by default: green, blue, grey, red
default_colorCodes = {'GpuFromHost' : 'red', default_colorCodes = {'GpuFromHost': 'red',
'HostFromGpu' : 'red', 'HostFromGpu': 'red',
'Scan' : 'yellow', 'Scan': 'yellow',
'Shape' : 'cyan', 'Shape': 'cyan',
'IfElse' : 'magenta', 'IfElse': 'magenta',
'Elemwise': '#FFAABB', 'Elemwise': '#FFAABB',
'Subtensor': '#FFAAFF'} 'Subtensor': '#FFAAFF'}
def pydotprint(fct, outfile=None, def pydotprint(fct, outfile=None,
compact=True, format='png', with_ids=False, compact=True, format='png', with_ids=False,
high_contrast=True, cond_highlight=None, colorCodes=None, high_contrast=True, cond_highlight=None, colorCodes=None,
...@@ -423,7 +478,7 @@ def pydotprint(fct, outfile=None, ...@@ -423,7 +478,7 @@ def pydotprint(fct, outfile=None,
in files with the same name as the name given for the main in files with the same name as the name given for the main
file to which the name of the scan op is concatenated and file to which the name of the scan op is concatenated and
the index in the toposort of the scan. the index in the toposort of the scan.
This index can be printed in the graph with the option with_ids. This index can be printed with the option with_ids.
:param var_with_name_simple: If true and a variable have a name, :param var_with_name_simple: If true and a variable have a name,
we will print only the variable name. we will print only the variable name.
Otherwise, we concatenate the type to the var name. Otherwise, we concatenate the type to the var name.
...@@ -439,23 +494,25 @@ def pydotprint(fct, outfile=None, ...@@ -439,23 +494,25 @@ def pydotprint(fct, outfile=None,
label each edge between an input and the Apply node with the label each edge between an input and the Apply node with the
input's index. input's index.
green boxes are inputs to the graph green boxes are inputs variables to the graph
blue boxes are outputs of the graph blue boxes are outputs variables of the graph
grey boxes are vars generated by the graph that are not outputs and are not used grey boxes are variables that are not outputs and are not used
red ellipses are transfers from/to the gpu (ops with names GpuFromHost, HostFromGpu) red ellipses are transfers from/to the gpu (ops with names GpuFromHost,
HostFromGpu)
""" """
if colorCodes is None: if colorCodes is None:
colorCodes = default_colorCodes colorCodes = default_colorCodes
if outfile is None: if outfile is None:
outfile = os.path.join(config.compiledir,'theano.pydotprint.' + outfile = os.path.join(config.compiledir, 'theano.pydotprint.' +
config.device + '.' + format) config.device + '.' + format)
if isinstance(fct, Function): if isinstance(fct, Function):
mode = fct.maker.mode mode = fct.maker.mode
fct_env = fct.maker.env fct_env = fct.maker.env
if not isinstance(mode,ProfileMode) or not mode.profile_stats.has_key(fct): if (not isinstance(mode, ProfileMode)
mode=None or not fct in mode.profile_stats):
mode = None
elif isinstance(fct, gof.Env): elif isinstance(fct, gof.Env):
mode = None mode = None
fct_env = fct fct_env = fct
...@@ -463,28 +520,28 @@ def pydotprint(fct, outfile=None, ...@@ -463,28 +520,28 @@ def pydotprint(fct, outfile=None,
raise ValueError(('pydotprint expects as input a theano.function or ' raise ValueError(('pydotprint expects as input a theano.function or '
'the env of a function!'), fct) 'the env of a function!'), fct)
try: if not pydot_imported:
import pydot as pd raise RuntimeError("Failed to import pydot. You must install pydot"
except ImportError: " for `pydotprint` to work.")
print ("Failed to import pydot. You must install pydot for "
"`pydotprint` to work.")
return return
g=pd.Dot() g = pd.Dot()
if cond_highlight is not None: if cond_highlight is not None:
c1 = pd.Cluster('Left') c1 = pd.Cluster('Left')
c2 = pd.Cluster('Right') c2 = pd.Cluster('Right')
c3 = pd.Cluster('Middle') c3 = pd.Cluster('Middle')
cond = None cond = None
for node in fct_env.toposort(): for node in fct_env.toposort():
if node.op.__class__.__name__=='IfElse' and node.op.name == cond_highlight: if (node.op.__class__.__name__ == 'IfElse'
and node.op.name == cond_highlight):
cond = node cond = node
if cond is None: if cond is None:
_logger.warn("pydotprint: cond_highlight is set but there is no IfElse node in the graph") _logger.warn("pydotprint: cond_highlight is set but there is no"
" IfElse node in the graph")
cond_highlight = None cond_highlight = None
if cond_highlight is not None: if cond_highlight is not None:
def recursive_pass(x,ls): def recursive_pass(x, ls):
if not x.owner: if not x.owner:
return ls return ls
else: else:
...@@ -493,8 +550,8 @@ def pydotprint(fct, outfile=None, ...@@ -493,8 +550,8 @@ def pydotprint(fct, outfile=None,
ls += recursive_pass(inp, ls) ls += recursive_pass(inp, ls)
return ls return ls
left = set(recursive_pass(cond.inputs[1],[])) left = set(recursive_pass(cond.inputs[1], []))
right =set(recursive_pass(cond.inputs[2],[])) right = set(recursive_pass(cond.inputs[2], []))
middle = left.intersection(right) middle = left.intersection(right)
left = left.difference(middle) left = left.difference(middle)
right = right.difference(middle) right = right.difference(middle)
...@@ -502,10 +559,9 @@ def pydotprint(fct, outfile=None, ...@@ -502,10 +559,9 @@ def pydotprint(fct, outfile=None,
left = list(left) left = list(left)
right = list(right) right = list(right)
var_str={} var_str = {}
all_strings = set() all_strings = set()
def var_name(var): def var_name(var):
if var in var_str: if var in var_str:
return var_str[var] return var_str[var]
...@@ -514,75 +570,82 @@ def pydotprint(fct, outfile=None, ...@@ -514,75 +570,82 @@ def pydotprint(fct, outfile=None,
if var_with_name_simple: if var_with_name_simple:
varstr = var.name varstr = var.name
else: else:
varstr = 'name='+var.name+" "+str(var.type) varstr = 'name=' + var.name + " " + str(var.type)
elif isinstance(var,gof.Constant): elif isinstance(var, gof.Constant):
dstr = 'val='+str(numpy.asarray(var.data)) dstr = 'val=' + str(numpy.asarray(var.data))
if '\n' in dstr: if '\n' in dstr:
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
varstr = '%s %s'% (dstr, str(var.type)) varstr = '%s %s' % (dstr, str(var.type))
elif var in input_update and input_update[var].variable.name is not None: elif (var in input_update
and input_update[var].variable.name is not None):
if var_with_name_simple: if var_with_name_simple:
varstr = input_update[var].variable.name+" UPDATE" varstr = input_update[var].variable.name + " UPDATE"
else: else:
varstr = input_update[var].variable.name+" UPDATE "+str(var.type) varstr = (input_update[var].variable.name + " UPDATE "
+ str(var.type))
else: else:
#a var id is needed as otherwise var with the same type will be merged in the graph. #a var id is needed as otherwise var with the same type will be
#merged in the graph.
varstr = str(var.type) varstr = str(var.type)
if (varstr in all_strings) or with_ids: if (varstr in all_strings) or with_ids:
idx = ' id=' + str(len(var_str)) idx = ' id=' + str(len(var_str))
if len(varstr)+len(idx) > max_label_size: if len(varstr) + len(idx) > max_label_size:
varstr = varstr[:max_label_size-3-len(idx)]+idx+'...' varstr = varstr[:max_label_size - 3 - len(idx)] + idx + '...'
else: else:
varstr = varstr + idx varstr = varstr + idx
elif len(varstr) > max_label_size: elif len(varstr) > max_label_size:
varstr = varstr[:max_label_size-3]+'...' varstr = varstr[:max_label_size - 3] + '...'
var_str[var]=varstr var_str[var] = varstr
all_strings.add(varstr) all_strings.add(varstr)
return varstr return varstr
topo = fct_env.toposort() topo = fct_env.toposort()
apply_name_cache = {} apply_name_cache = {}
def apply_name(node): def apply_name(node):
if node in apply_name_cache: if node in apply_name_cache:
return apply_name_cache[node] return apply_name_cache[node]
prof_str='' prof_str = ''
if mode: if mode:
time = mode.profile_stats[fct].apply_time.get(node,0) time = mode.profile_stats[fct].apply_time.get(node, 0)
#second, % total time in profiler, %fct time in profiler #second, % total time in profiler, %fct time in profiler
if mode.local_time==0: if mode.local_time == 0:
pt=0 pt = 0
else: pt=time*100/mode.local_time else:
if mode.profile_stats[fct].fct_callcount==0: pt = time * 100 / mode.local_time
pf=0 if mode.profile_stats[fct].fct_callcount == 0:
else: pf = time*100/mode.profile_stats[fct].fct_call_time pf = 0
prof_str=' (%.3fs,%.3f%%,%.3f%%)'%(time,pt,pf) else:
applystr = str(node.op).replace(':','_') pf = time * 100 / mode.profile_stats[fct].fct_call_time
prof_str = ' (%.3fs,%.3f%%,%.3f%%)' % (time, pt, pf)
applystr = str(node.op).replace(':', '_')
applystr += prof_str applystr += prof_str
if (applystr in all_strings) or with_ids: if (applystr in all_strings) or with_ids:
idx = ' id='+str(topo.index(node)) idx = ' id=' + str(topo.index(node))
if len(applystr)+len(idx) > max_label_size: if len(applystr) + len(idx) > max_label_size:
applystr = applystr[:max_label_size-3-len(idx)]+idx+'...' applystr = (applystr[:max_label_size - 3 - len(idx)] + idx
+ '...')
else: else:
applystr = applystr + idx applystr = applystr + idx
elif len(applystr) > max_label_size: elif len(applystr) > max_label_size:
applystr = applystr[:max_label_size-3]+'...' applystr = applystr[:max_label_size - 3] + '...'
all_strings.add(applystr) all_strings.add(applystr)
apply_name_cache[node] = applystr apply_name_cache[node] = applystr
return applystr return applystr
# Update the inputs that have an update function # Update the inputs that have an update function
input_update={} input_update = {}
outputs = list(fct_env.outputs) outputs = list(fct_env.outputs)
if isinstance(fct, Function): if isinstance(fct, Function):
for i in reversed(fct.maker.expanded_inputs): for i in reversed(fct.maker.expanded_inputs):
if i.update is not None: if i.update is not None:
input_update[outputs.pop()] = i input_update[outputs.pop()] = i
apply_shape='ellipse' apply_shape = 'ellipse'
var_shape='box' var_shape = 'box'
for node_idx,node in enumerate(topo): for node_idx, node in enumerate(topo):
astr=apply_name(node) astr = apply_name(node)
use_color = None use_color = None
for opName, color in colorCodes.items(): for opName, color in colorCodes.items():
...@@ -593,9 +656,9 @@ def pydotprint(fct, outfile=None, ...@@ -593,9 +656,9 @@ def pydotprint(fct, outfile=None,
nw_node = pd.Node(astr, shape=apply_shape) nw_node = pd.Node(astr, shape=apply_shape)
elif high_contrast: elif high_contrast:
nw_node = pd.Node(astr, style='filled', fillcolor=use_color, nw_node = pd.Node(astr, style='filled', fillcolor=use_color,
shape = apply_shape) shape=apply_shape)
else: else:
nw_node = pd.Node(astr,color=use_color, shape = apply_shape) nw_node = pd.Node(astr, color=use_color, shape=apply_shape)
g.add_node(nw_node) g.add_node(nw_node)
if cond_highlight: if cond_highlight:
if node in middle: if node in middle:
...@@ -605,51 +668,50 @@ def pydotprint(fct, outfile=None, ...@@ -605,51 +668,50 @@ def pydotprint(fct, outfile=None,
elif node in right: elif node in right:
c2.add_node(nw_node) c2.add_node(nw_node)
for id, var in enumerate(node.inputs):
for id,var in enumerate(node.inputs): varstr = var_name(var)
varstr=var_name(var) label = str(var.type)
label=str(var.type) if len(label) > max_label_size:
if len(label)>max_label_size: label = label[:max_label_size - 3] + '...'
label = label[:max_label_size-3]+'...' if len(node.inputs) > 1:
if len(node.inputs)>1: label = str(id) + ' ' + label
label=str(id)+' '+label
if var.owner is None: if var.owner is None:
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr g.add_node(pd.Node(varstr,
,style = 'filled' style='filled',
, fillcolor='green',shape=var_shape)) fillcolor='green',
shape=var_shape))
else: else:
g.add_node(pd.Node(varstr,color='green',shape=var_shape)) g.add_node(pd.Node(varstr, color='green', shape=var_shape))
g.add_edge(pd.Edge(varstr,astr, label=label)) g.add_edge(pd.Edge(varstr, astr, label=label))
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(varstr,astr, label=label)) g.add_edge(pd.Edge(varstr, astr, label=label))
else: else:
#no name, so we don't make a var ellipse #no name, so we don't make a var ellipse
g.add_edge(pd.Edge(apply_name(var.owner),astr, label=label)) g.add_edge(pd.Edge(apply_name(var.owner), astr, label=label))
for id, var in enumerate(node.outputs):
for id,var in enumerate(node.outputs): varstr = var_name(var)
varstr=var_name(var) out = any([x[0] == 'output' for x in var.clients])
out = any([x[0]=='output' for x in var.clients]) label = str(var.type)
label=str(var.type) if len(node.outputs) > 1:
if len(node.outputs)>1: label = str(id) + ' ' + label
label=str(id)+' '+label if len(label) > max_label_size:
if len(label)>max_label_size: label = label[:max_label_size - 3] + '...'
label = label[:max_label_size-3]+'...'
if out: if out:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr,style='filled' g.add_node(pd.Node(varstr, style='filled',
,fillcolor='blue',shape=var_shape)) fillcolor='blue', shape=var_shape))
else: else:
g.add_node(pd.Node(varstr,color='blue',shape=var_shape)) g.add_node(pd.Node(varstr, color='blue', shape=var_shape))
elif len(var.clients)==0: elif len(var.clients) == 0:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr,style='filled', g.add_node(pd.Node(varstr, style='filled',
fillcolor='grey',shape=var_shape)) fillcolor='grey', shape=var_shape))
else: else:
g.add_node(pd.Node(varstr,color='grey',shape=var_shape)) g.add_node(pd.Node(varstr, color='grey', shape=var_shape))
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
# else: # else:
...@@ -660,37 +722,32 @@ def pydotprint(fct, outfile=None, ...@@ -660,37 +722,32 @@ def pydotprint(fct, outfile=None,
g.add_subgraph(c2) g.add_subgraph(c2)
g.add_subgraph(c3) g.add_subgraph(c3)
if not outfile.endswith('.'+format): if not outfile.endswith('.' + format):
outfile+='.'+format outfile += '.' + format
g.write(outfile, prog='dot', format=format) g.write(outfile, prog='dot', format=format)
if print_output_file: if print_output_file:
print 'The output file is available at',outfile print 'The output file is available at', outfile
if scan_graphs: if scan_graphs:
scan_ops = [(idx, x) for idx,x in enumerate(fct_env.toposort()) if isinstance(x.op, theano.scan_module.scan_op.Scan)] scan_ops = [(idx, x) for idx, x in enumerate(fct_env.toposort())
if isinstance(x.op, theano.scan_module.scan_op.Scan)]
path, fn = os.path.split(outfile) path, fn = os.path.split(outfile)
basename = '.'.join(fn.split('.')[:-1]) basename = '.'.join(fn.split('.')[:-1])
# Safe way of doing things .. a file name may contain multiple . # Safe way of doing things .. a file name may contain multiple .
ext = fn[len(basename):] ext = fn[len(basename):]
for idx, scan_op in scan_ops: for idx, scan_op in scan_ops:
# is there a chance that name is not defined? # is there a chance that name is not defined?
if hasattr(scan_op.op,'name'): if hasattr(scan_op.op, 'name'):
new_name = basename+'_'+scan_op.op.name+'_'+str(idx) new_name = basename + '_' + scan_op.op.name + '_' + str(idx)
else: else:
new_name = basename+'_'+str(idx) new_name = basename + '_' + str(idx)
new_name = os.path.join(path, new_name+ext) new_name = os.path.join(path, new_name + ext)
pydotprint(scan_op.op.fn, new_name, compact, format, with_ids, pydotprint(scan_op.op.fn, new_name, compact, format, with_ids,
high_contrast, cond_highlight, colorCodes, high_contrast, cond_highlight, colorCodes,
max_label_size, scan_graphs) max_label_size, scan_graphs)
def pydotprint_variables(vars, def pydotprint_variables(vars,
outfile=None, outfile=None,
format='png', format='png',
...@@ -704,7 +761,7 @@ def pydotprint_variables(vars, ...@@ -704,7 +761,7 @@ def pydotprint_variables(vars,
if colorCodes is None: if colorCodes is None:
colorCodes = default_colorCodes colorCodes = default_colorCodes
if outfile is None: if outfile is None:
outfile = os.path.join(config.compiledir,'theano.pydotprint.' + outfile = os.path.join(config.compiledir, 'theano.pydotprint.' +
config.device + '.' + format) config.device + '.' + format)
try: try:
import pydot as pd import pydot as pd
...@@ -712,12 +769,13 @@ def pydotprint_variables(vars, ...@@ -712,12 +769,13 @@ def pydotprint_variables(vars,
print ("Failed to import pydot. You must install pydot for " print ("Failed to import pydot. You must install pydot for "
"`pydotprint_variables` to work.") "`pydotprint_variables` to work.")
return return
g=pd.Dot() g = pd.Dot()
my_list = {} my_list = {}
orphanes = [] orphanes = []
if type(vars) not in (list,tuple): if type(vars) not in (list, tuple):
vars = [vars] vars = [vars]
var_str = {} var_str = {}
def var_name(var): def var_name(var):
if var in var_str: if var in var_str:
return var_str[var] return var_str[var]
...@@ -726,26 +784,27 @@ def pydotprint_variables(vars, ...@@ -726,26 +784,27 @@ def pydotprint_variables(vars,
if var_with_name_simple: if var_with_name_simple:
varstr = var.name varstr = var.name
else: else:
varstr = 'name='+var.name+" "+str(var.type) varstr = 'name=' + var.name + " " + str(var.type)
elif isinstance(var,gof.Constant): elif isinstance(var, gof.Constant):
dstr = 'val='+str(var.data) dstr = 'val=' + str(var.data)
if '\n' in dstr: if '\n' in dstr:
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
varstr = '%s %s'% (dstr, str(var.type)) varstr = '%s %s' % (dstr, str(var.type))
else: else:
#a var id is needed as otherwise var with the same type will be merged in the graph. #a var id is needed as otherwise var with the same type will be
#merged in the graph.
varstr = str(var.type) varstr = str(var.type)
varstr += ' ' + str(len(var_str)) varstr += ' ' + str(len(var_str))
if len(varstr) > max_label_size: if len(varstr) > max_label_size:
varstr = varstr[:max_label_size-3]+'...' varstr = varstr[:max_label_size - 3] + '...'
var_str[var]=varstr var_str[var] = varstr
return varstr return varstr
def apply_name(node): def apply_name(node):
name = str(node.op).replace(':','_') name = str(node.op).replace(':', '_')
if len(name) > max_label_size: if len(name) > max_label_size:
name = name[:max_label_size-3]+'...' name = name[:max_label_size - 3] + '...'
return name return name
def plot_apply(app, d): def plot_apply(app, d):
...@@ -755,53 +814,52 @@ def pydotprint_variables(vars, ...@@ -755,53 +814,52 @@ def pydotprint_variables(vars,
return return
astr = apply_name(app) + '_' + str(len(my_list.keys())) astr = apply_name(app) + '_' + str(len(my_list.keys()))
if len(astr) > max_label_size: if len(astr) > max_label_size:
astr = astr[:max_label_size-3]+'...' astr = astr[:max_label_size - 3] + '...'
my_list[app] = astr my_list[app] = astr
use_color = None use_color = None
for opName, color in colorCodes.items(): for opName, color in colorCodes.items():
if opName in app.op.__class__.__name__ : if opName in app.op.__class__.__name__:
use_color = color use_color = color
if use_color is None: if use_color is None:
g.add_node(pd.Node(astr, shape='box')) g.add_node(pd.Node(astr, shape='box'))
elif high_contrast: elif high_contrast:
g.add_node(pd.Node(astr, style='filled', fillcolor=use_color, g.add_node(pd.Node(astr, style='filled', fillcolor=use_color,
shape = 'box')) shape='box'))
else: else:
g.add_node(pd.Nonde(astr,color=use_color, shape = 'box')) g.add_node(pd.Nonde(astr, color=use_color, shape='box'))
for i,nd in enumerate(app.inputs): for i, nd in enumerate(app.inputs):
if nd not in my_list: if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys())) varastr = var_name(nd) + '_' + str(len(my_list.keys()))
if len(varastr) > max_label_size: if len(varastr) > max_label_size:
varastr = varastr[:max_label_size-3]+'...' varastr = varastr[:max_label_size - 3] + '...'
my_list[nd] = varastr my_list[nd] = varastr
if nd.owner is not None: if nd.owner is not None:
g.add_node(pd.Node(varastr)) g.add_node(pd.Node(varastr))
elif high_contrast: elif high_contrast:
g.add_node(pd.Node(varastr, style ='filled', g.add_node(pd.Node(varastr, style='filled',
fillcolor='green')) fillcolor='green'))
else: else:
g.add_node(pd.Node(varastr, color='green')) g.add_node(pd.Node(varastr, color='green'))
else: else:
varastr = my_list[nd] varastr = my_list[nd]
label = '' label = ''
if len(app.inputs)>1: if len(app.inputs) > 1:
label = str(i) label = str(i)
g.add_edge(pd.Edge(varastr, astr, label = label)) g.add_edge(pd.Edge(varastr, astr, label=label))
for i,nd in enumerate(app.outputs): for i, nd in enumerate(app.outputs):
if nd not in my_list: if nd not in my_list:
varastr = var_name(nd) + '_' + str(len(my_list.keys())) varastr = var_name(nd) + '_' + str(len(my_list.keys()))
if len(varastr) > max_label_size: if len(varastr) > max_label_size:
varastr = varastr[:max_label_size-3]+'...' varastr = varastr[:max_label_size - 3] + '...'
my_list[nd] = varastr my_list[nd] = varastr
color = None color = None
if nd in vars: if nd in vars:
color = 'blue' color = 'blue'
elif nd in orphanes : elif nd in orphanes:
color = 'gray' color = 'gray'
if color is None: if color is None:
g.add_node(pd.Node(varastr)) g.add_node(pd.Node(varastr))
...@@ -809,17 +867,16 @@ def pydotprint_variables(vars, ...@@ -809,17 +867,16 @@ def pydotprint_variables(vars,
g.add_node(pd.Node(varastr, style='filled', g.add_node(pd.Node(varastr, style='filled',
fillcolor=color)) fillcolor=color))
else: else:
g.add_node(pd.Node(varastr, color = color)) g.add_node(pd.Node(varastr, color=color))
else: else:
varastr = my_list[nd] varastr = my_list[nd]
label = '' label = ''
if len(app.outputs) > 1: if len(app.outputs) > 1:
label = str(i) label = str(i)
g.add_edge(pd.Edge(astr, varastr,label = label)) g.add_edge(pd.Edge(astr, varastr, label=label))
for nd in app.inputs: for nd in app.inputs:
if nd.owner: if nd.owner:
plot_apply(nd.owner, d-1) plot_apply(nd.owner, d - 1)
for nd in vars: for nd in vars:
if nd.owner: if nd.owner:
...@@ -833,8 +890,7 @@ def pydotprint_variables(vars, ...@@ -833,8 +890,7 @@ def pydotprint_variables(vars,
g.write_png(outfile, prog='dot') g.write_png(outfile, prog='dot')
print 'The output file is available at',outfile print 'The output file is available at', outfile
class _TagGenerator: class _TagGenerator:
...@@ -864,13 +920,15 @@ class _TagGenerator: ...@@ -864,13 +920,15 @@ class _TagGenerator:
while number != 0: while number != 0:
remainder = number % base remainder = number % base
new_char = chr(ord('A')+remainder) new_char = chr(ord('A') + remainder)
rval = new_char + rval rval = new_char + rval
number /= base number /= base
return rval return rval
def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator = None):
def min_informative_str(obj, indent_level=0,
_prev_obs=None, _tag_generator=None):
""" """
Returns a string specifying to the user what obj is Returns a string specifying to the user what obj is
The string will print out as much of the graph as is needed The string will print out as much of the graph as is needed
...@@ -890,15 +948,18 @@ def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator ...@@ -890,15 +948,18 @@ def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator
Basic design philosophy Basic design philosophy
----------------------- -----------------------
The idea behind this function is that it can be used as parts of command line tools
for debugging or for error messages. The information displayed is intended to be
concise and easily read by a human. In particular, it is intended to be informative
when working with large graphs composed of subgraphs from several different people's
code, as in pylearn2.
Stopping expanding subtrees when named variables are encountered makes it easier to The idea behind this function is that it can be used as parts of
understand what is happening when a graph formed by composing several different graphs command line tools for debugging or for error messages. The
made by code written by different authors has a bug. information displayed is intended to be concise and easily read by
a human. In particular, it is intended to be informative when
working with large graphs composed of subgraphs from several
different people's code, as in pylearn2.
Stopping expanding subtrees when named variables are encountered
makes it easier to understand what is happening when a graph
formed by composing several different graphs made by code written
by different authors has a bug.
An example output is: An example output is:
...@@ -907,17 +968,22 @@ def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator ...@@ -907,17 +968,22 @@ def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator
C. log_likelihood_h C. log_likelihood_h
If the user is told they have a problem computing this value, it's obvious that either If the user is told they have a problem computing this value, it's
log_likelihood_h or log_likelihood_v_given_h has the wrong dimensionality. The variable's obvious that either log_likelihood_h or log_likelihood_v_given_h
str object would only tell you that there was a problem with an Elemwise{add_no_inplace}. has the wrong dimensionality. The variable's str object would only
Since there are many such ops in a typical graph, such an error message is considerably tell you that there was a problem with an
less informative. Error messages based on this function should convey much more information Elemwise{add_no_inplace}. Since there are many such ops in a
about the location in the graph of the error while remaining succint. typical graph, such an error message is considerably less
informative. Error messages based on this function should convey
much more information about the location in the graph of the error
while remaining succint.
One final note: the use of capital letters to uniquely identify nodes within the graph One final note: the use of capital letters to uniquely identify
is motivated by legibility. I do not use numbers or lower case letters since these are nodes within the graph is motivated by legibility. I do not use
pretty common as parts of names of ops, etc. I also don't use the object's id like in numbers or lower case letters since these are pretty common as
debugprint because it gives such a long string that takes time to visually diff. parts of names of ops, etc. I also don't use the object's id like
in debugprint because it gives such a long string that takes time
to visually diff.
""" """
...@@ -926,7 +992,6 @@ def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator ...@@ -926,7 +992,6 @@ def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator
indent = '\t' * indent_level indent = '\t' * indent_level
if obj in _prev_obs: if obj in _prev_obs:
tag = _prev_obs[obj] tag = _prev_obs[obj]
...@@ -939,7 +1004,6 @@ def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator ...@@ -939,7 +1004,6 @@ def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator
_prev_obs[obj] = cur_tag _prev_obs[obj] = cur_tag
if hasattr(obj, '__array__'): if hasattr(obj, '__array__'):
name = '<ndarray>' name = '<ndarray>'
elif hasattr(obj, 'name') and obj.name is not None: elif hasattr(obj, 'name') and obj.name is not None:
...@@ -948,12 +1012,11 @@ def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator ...@@ -948,12 +1012,11 @@ def min_informative_str(obj, indent_level = 0, _prev_obs = None, _tag_generator
name = str(obj.owner.op) name = str(obj.owner.op)
for ipt in obj.owner.inputs: for ipt in obj.owner.inputs:
name += '\n' + min_informative_str(ipt, name += '\n' + min_informative_str(ipt,
indent_level = indent_level + 1, indent_level=indent_level + 1,
_prev_obs = _prev_obs, _tag_generator = _tag_generator) _prev_obs=_prev_obs, _tag_generator=_tag_generator)
else: else:
name = str(obj) name = str(obj)
prefix = cur_tag + '. ' prefix = cur_tag + '. '
rval = indent + prefix + name rval = indent + prefix + name
......
...@@ -22,14 +22,12 @@ def test_pydotprint_cond_highlight(): ...@@ -22,14 +22,12 @@ def test_pydotprint_cond_highlight():
""" """
# Skip test if pydot is not available. # Skip test if pydot is not available.
try: if not theano.printing.pydot_imported:
import pydot
except ImportError:
raise SkipTest('pydot not available') raise SkipTest('pydot not available')
x = tensor.dvector() x = tensor.dvector()
f = theano.function([x], x*2) f = theano.function([x], x * 2)
f([1,2,3,4]) f([1, 2, 3, 4])
s = StringIO.StringIO() s = StringIO.StringIO()
new_handler = logging.StreamHandler(s) new_handler = logging.StreamHandler(s)
...@@ -39,19 +37,21 @@ def test_pydotprint_cond_highlight(): ...@@ -39,19 +37,21 @@ def test_pydotprint_cond_highlight():
theano.theano_logger.removeHandler(orig_handler) theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler) theano.theano_logger.addHandler(new_handler)
try: try:
theano.printing.pydotprint(f, cond_highlight = True, print_output_file=False) theano.printing.pydotprint(f, cond_highlight=True,
print_output_file=False)
finally: finally:
theano.theano_logger.addHandler(orig_handler) theano.theano_logger.addHandler(orig_handler)
theano.theano_logger.removeHandler(new_handler) theano.theano_logger.removeHandler(new_handler)
assert s.getvalue() == 'pydotprint: cond_highlight is set but there is no IfElse node in the graph\n' assert (s.getvalue() == 'pydotprint: cond_highlight is set but there'
' is no IfElse node in the graph\n')
def test_pydotprint_profile(): def test_pydotprint_profile():
"""Just check that pydotprint does not crash with ProfileMode.""" """Just check that pydotprint does not crash with ProfileMode."""
A = tensor.matrix() A = tensor.matrix()
f = theano.function([A], A+1, mode='ProfileMode') f = theano.function([A], A + 1, mode='ProfileMode')
theano.printing.pydotprint(f, print_output_file=False) theano.printing.pydotprint(f, print_output_file=False)
...@@ -59,17 +59,17 @@ def test_min_informative_str(): ...@@ -59,17 +59,17 @@ def test_min_informative_str():
""" evaluates a reference output to make sure the """ evaluates a reference output to make sure the
min_informative_str function works as intended """ min_informative_str function works as intended """
A = tensor.matrix(name = 'A') A = tensor.matrix(name='A')
B = tensor.matrix(name = 'B') B = tensor.matrix(name='B')
C = A + B C = A + B
C.name = 'C' C.name = 'C'
D = tensor.matrix(name = 'D') D = tensor.matrix(name='D')
E = tensor.matrix(name = 'E') E = tensor.matrix(name='E')
F = D + E F = D + E
G = C + F G = C + F
mis = min_informative_str(G) mis = min_informative_str(G).replace("\t", " ")
reference = """A. Elemwise{add,no_inplace} reference = """A. Elemwise{add,no_inplace}
B. C B. C
...@@ -78,7 +78,7 @@ def test_min_informative_str(): ...@@ -78,7 +78,7 @@ def test_min_informative_str():
E. E""" E. E"""
if mis != reference: if mis != reference:
print '--'+mis+'--' print '--' + mis + '--'
print '--'+reference+'--' print '--' + reference + '--'
assert mis == reference assert mis == reference
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论