提交 c2efab91 authored 作者: carriepl's avatar carriepl

Merge pull request #2661 from thomasmesnard/pep8-1

pep8 on printing.py
...@@ -97,7 +97,8 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -97,7 +97,8 @@ def debugprint(obj, depth=-1, print_type=False,
profile_list.extend([None for item in obj.outputs]) profile_list.extend([None for item in obj.outputs])
elif isinstance(obj, Function): elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs) results_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs]) profile_list.extend(
[obj.profile for item in obj.maker.fgraph.outputs])
order = obj.maker.fgraph.toposort() order = obj.maker.fgraph.toposort()
elif isinstance(obj, gof.FunctionGraph): elif isinstance(obj, gof.FunctionGraph):
results_to_print.extend(obj.outputs) results_to_print.extend(obj.outputs)
...@@ -116,10 +117,10 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -116,10 +117,10 @@ def debugprint(obj, depth=-1, print_type=False,
for r, p in zip(results_to_print, profile_list): for r, p in zip(results_to_print, profile_list):
# Add the parent scan op to the list as well # Add the parent scan op to the list as well
if (hasattr(r.owner, 'op') and if (hasattr(r.owner, 'op') and
isinstance(r.owner.op, theano.scan_module.scan_op.Scan)): isinstance(r.owner.op, theano.scan_module.scan_op.Scan)):
scan_ops.append(r) scan_ops.append(r)
if p != None: if p is not None:
print >> file, """ print >> file, """
Timing Info Timing Info
----------- -----------
...@@ -130,7 +131,7 @@ Timing Info ...@@ -130,7 +131,7 @@ Timing Info
<total time> time for this node + total times for this node's ancestors <total time> time for this node + total times for this node's ancestors
<% total time> total time for this node over total computation time <% total time> total time for this node over total computation time
N.B.: N.B.:
* Times include the node time and the function overhead. * Times include the node time and the function overhead.
* <total time> and <% total time> may over-count computation times * <total time> and <% total time> may over-count computation times
if inputs to a node share a common ancestor and should be viewed as a if inputs to a node share a common ancestor and should be viewed as a
...@@ -264,6 +265,7 @@ class PrinterState(gof.utils.scratchpad): ...@@ -264,6 +265,7 @@ class PrinterState(gof.utils.scratchpad):
props = {} props = {}
return PrinterState(self, **dict(props, **more_props)) return PrinterState(self, **dict(props, **more_props))
class OperatorPrinter: class OperatorPrinter:
def __init__(self, operator, precedence, assoc='left'): def __init__(self, operator, precedence, assoc='left'):
...@@ -279,13 +281,13 @@ class OperatorPrinter: ...@@ -279,13 +281,13 @@ class OperatorPrinter:
raise TypeError("operator %s cannot represent a variable that is " raise TypeError("operator %s cannot represent a variable that is "
"not the result of an operation" % self.operator) "not the result of an operation" % self.operator)
## Precedence seems to be buggy, see #249 # Precedence seems to be buggy, see #249
## So, in doubt, we parenthesize everything. # So, in doubt, we parenthesize everything.
#outer_precedence = getattr(pstate, 'precedence', -999999) # outer_precedence = getattr(pstate, 'precedence', -999999)
#outer_assoc = getattr(pstate, 'assoc', 'none') # outer_assoc = getattr(pstate, 'assoc', 'none')
#if outer_precedence > self.precedence: # if outer_precedence > self.precedence:
# parenthesize = True # parenthesize = True
#else: # else:
# parenthesize = False # parenthesize = False
parenthesize = True parenthesize = True
...@@ -293,7 +295,7 @@ class OperatorPrinter: ...@@ -293,7 +295,7 @@ class OperatorPrinter:
max_i = len(node.inputs) - 1 max_i = len(node.inputs) - 1
for i, input in enumerate(node.inputs): for i, input in enumerate(node.inputs):
if (self.assoc == 'left' and i != 0 or self.assoc == 'right' if (self.assoc == 'left' and i != 0 or self.assoc == 'right'
and i != max_i): and i != max_i):
s = pprinter.process(input, pstate.clone( s = pprinter.process(input, pstate.clone(
precedence=self.precedence + 1e-6)) precedence=self.precedence + 1e-6))
else: else:
...@@ -368,7 +370,6 @@ class MemberPrinter: ...@@ -368,7 +370,6 @@ class MemberPrinter:
if node is None: if node is None:
raise TypeError("function %s cannot represent a variable that is" raise TypeError("function %s cannot represent a variable that is"
" not the result of an operation" % self.function) " not the result of an operation" % self.function)
names = self.names
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
input = node.inputs[0] input = node.inputs[0]
...@@ -463,7 +464,8 @@ class PPrinter: ...@@ -463,7 +464,8 @@ class PPrinter:
inv_updates = dict((b, a) for (a, b) in updates.iteritems()) inv_updates = dict((b, a) for (a, b) in updates.iteritems())
i = 1 i = 1
for node in gof.graph.io_toposort(list(inputs) + updates.keys(), for node in gof.graph.io_toposort(list(inputs) + updates.keys(),
list(outputs) + updates.values()): list(outputs) +
updates.values()):
for output in node.outputs: for output in node.outputs:
if output in inv_updates: if output in inv_updates:
name = str(inv_updates[output]) name = str(inv_updates[output])
...@@ -475,8 +477,8 @@ class PPrinter: ...@@ -475,8 +477,8 @@ class PPrinter:
name = 'out[%i]' % outputs.index(output) name = 'out[%i]' % outputs.index(output)
else: else:
name = output.name name = output.name
#backport # backport
#name = 'out[%i]' % outputs.index(output) if output.name # name = 'out[%i]' % outputs.index(output) if output.name
# is None else output.name # is None else output.name
current = output current = output
try: try:
...@@ -639,8 +641,8 @@ def pydotprint(fct, outfile=None, ...@@ -639,8 +641,8 @@ def pydotprint(fct, outfile=None,
mode = fct.maker.mode mode = fct.maker.mode
profile = getattr(fct, "profile", None) profile = getattr(fct, "profile", None)
if (not isinstance(mode, ProfileMode) if (not isinstance(mode, ProfileMode)
or not fct in mode.profile_stats): or fct not in mode.profile_stats):
mode = None mode = None
outputs = fct.maker.fgraph.outputs outputs = fct.maker.fgraph.outputs
topo = fct.maker.fgraph.toposort() topo = fct.maker.fgraph.toposort()
elif isinstance(fct, gof.FunctionGraph): elif isinstance(fct, gof.FunctionGraph):
...@@ -675,7 +677,7 @@ def pydotprint(fct, outfile=None, ...@@ -675,7 +677,7 @@ def pydotprint(fct, outfile=None,
cond = None cond = None
for node in topo: for node in topo:
if (node.op.__class__.__name__ == 'IfElse' if (node.op.__class__.__name__ == 'IfElse'
and node.op.name == cond_highlight): and node.op.name == cond_highlight):
cond = node cond = node
if cond is None: if cond is None:
_logger.warn("pydotprint: cond_highlight is set but there is no" _logger.warn("pydotprint: cond_highlight is set but there is no"
...@@ -842,11 +844,11 @@ def pydotprint(fct, outfile=None, ...@@ -842,11 +844,11 @@ def pydotprint(fct, outfile=None,
label = label[:max_label_size - 3] + '...' label = label[:max_label_size - 3] + '...'
param = {} param = {}
if hasattr(node.op, 'view_map') and id in reduce( if hasattr(node.op, 'view_map') and id in reduce(
list.__add__, node.op.view_map.values(), []): list.__add__, node.op.view_map.values(), []):
param['color'] = 'blue' param['color'] = 'blue'
elif hasattr(node.op, 'destroy_map') and id in reduce( elif hasattr(node.op, 'destroy_map') and id in reduce(
list.__add__, node.op.destroy_map.values(), []): list.__add__, node.op.destroy_map.values(), []):
param['color'] = 'red' param['color'] = 'red'
if var.owner is None: if var.owner is None:
if high_contrast: if high_contrast:
g.add_node(pd.Node(varstr, g.add_node(pd.Node(varstr,
...@@ -860,7 +862,8 @@ def pydotprint(fct, outfile=None, ...@@ -860,7 +862,8 @@ def pydotprint(fct, outfile=None,
g.add_edge(pd.Edge(varstr, astr, label=label, **param)) g.add_edge(pd.Edge(varstr, astr, label=label, **param))
else: else:
# no name, so we don't make a var ellipse # no name, so we don't make a var ellipse
g.add_edge(pd.Edge(apply_name(var.owner), astr, label=label, **param)) g.add_edge(pd.Edge(apply_name(var.owner), astr,
label=label, **param))
for id, var in enumerate(node.outputs): for id, var in enumerate(node.outputs):
varstr = var_name(var) varstr = var_name(var)
...@@ -952,8 +955,9 @@ def pydotprint_variables(vars, ...@@ -952,8 +955,9 @@ def pydotprint_variables(vars,
try: try:
import pydot as pd import pydot as pd
except ImportError: except ImportError:
print ("Failed to import pydot. You must install pydot for " str = ("Failed to import pydot. You must install pydot for " +
"`pydotprint_variables` to work.") "`pydotprint_variables` to work.")
print str
return return
g = pd.Dot() g = pd.Dot()
my_list = {} my_list = {}
...@@ -1192,9 +1196,11 @@ def min_informative_str(obj, indent_level=0, ...@@ -1192,9 +1196,11 @@ def min_informative_str(obj, indent_level=0,
elif hasattr(obj, 'owner') and obj.owner is not None: elif hasattr(obj, 'owner') and obj.owner is not None:
name = str(obj.owner.op) name = str(obj.owner.op)
for ipt in obj.owner.inputs: for ipt in obj.owner.inputs:
name += '\n' + min_informative_str(ipt, name += '\n'
indent_level=indent_level + 1, name += min_informative_str(ipt,
_prev_obs=_prev_obs, _tag_generator=_tag_generator) indent_level=indent_level + 1,
_prev_obs=_prev_obs,
_tag_generator=_tag_generator)
else: else:
name = str(obj) name = str(obj)
...@@ -1217,7 +1223,8 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None): ...@@ -1217,7 +1223,8 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
try: try:
import hashlib import hashlib
except ImportError: except ImportError:
raise RuntimeError("Can't run var_descriptor because hashlib is not available.") raise RuntimeError(
"Can't run var_descriptor because hashlib is not available.")
if _prev_obs is None: if _prev_obs is None:
_prev_obs = {} _prev_obs = {}
...@@ -1239,13 +1246,15 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None): ...@@ -1239,13 +1246,15 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
# it can have different semantics depending on the strides # it can have different semantics depending on the strides
# of the ndarray # of the ndarray
name = '<ndarray:' name = '<ndarray:'
name += 'strides=['+','.join(str(stride) for stride in obj.strides)+']' name += 'strides=[' + ','.join(str(stride)
name += ',digest='+hashlib.md5(obj).hexdigest()+'>' for stride in obj.strides) + ']'
name += ',digest=' + hashlib.md5(obj).hexdigest() + '>'
elif hasattr(obj, 'owner') and obj.owner is not None: elif hasattr(obj, 'owner') and obj.owner is not None:
name = str(obj.owner.op) + '(' name = str(obj.owner.op) + '('
name += ','.join(var_descriptor(ipt, name += ','.join(var_descriptor(ipt,
_prev_obs=_prev_obs, _tag_generator=_tag_generator) for ipt _prev_obs=_prev_obs,
in obj.owner.inputs) _tag_generator=_tag_generator)
for ipt in obj.owner.inputs)
name += ')' name += ')'
elif hasattr(obj, 'name') and obj.name is not None: elif hasattr(obj, 'name') and obj.name is not None:
# Only print the name if there is no owner. # Only print the name if there is no owner.
...@@ -1271,7 +1280,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None): ...@@ -1271,7 +1280,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
def position_independent_str(obj): def position_independent_str(obj):
if isinstance(obj, theano.gof.graph.Variable): if isinstance(obj, theano.gof.graph.Variable):
rval = 'theano_var' rval = 'theano_var'
rval += '{type='+str(obj.type)+'}' rval += '{type=' + str(obj.type) + '}'
else: else:
raise NotImplementedError() raise NotImplementedError()
...@@ -1288,13 +1297,15 @@ def hex_digest(x): ...@@ -1288,13 +1297,15 @@ def hex_digest(x):
try: try:
import hashlib import hashlib
except ImportError: except ImportError:
raise RuntimeError("Can't run hex_digest because hashlib is not available.") raise RuntimeError("Can't run hex_digest"
"because hashlib is not available.")
assert isinstance(x, np.ndarray) assert isinstance(x, np.ndarray)
rval = hashlib.md5(x.tostring()).hexdigest() rval = hashlib.md5(x.tostring()).hexdigest()
# hex digest must be annotated with strides to avoid collisions # hex digest must be annotated with strides to avoid collisions
# because the buffer interface only exposes the raw data, not # because the buffer interface only exposes the raw data, not
# any info about the semantics of how that data should be arranged # any info about the semantics of how that data should be arranged
# into a tensor # into a tensor
rval = rval + '|strides=[' + ','.join(str(stride) for stride in x.strides) + ']' rval = rval + '|strides=[' + ','.join(str(stride)
for stride in x.strides) + ']'
rval = rval + '|shape=[' + ','.join(str(s) for s in x.shape) + ']' rval = rval + '|shape=[' + ','.join(str(s) for s in x.shape) + ']'
return rval return rval
...@@ -18,7 +18,6 @@ except ImportError: ...@@ -18,7 +18,6 @@ except ImportError:
whitelist_flake8 = [ whitelist_flake8 = [
"updates.py", "updates.py",
"printing.py",
"__init__.py", "__init__.py",
"configparser.py", "configparser.py",
"ifelse.py", "ifelse.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论