提交 c2efab91 authored 作者: carriepl's avatar carriepl

Merge pull request #2661 from thomasmesnard/pep8-1

pep8 on printing.py
......@@ -97,7 +97,8 @@ def debugprint(obj, depth=-1, print_type=False,
profile_list.extend([None for item in obj.outputs])
elif isinstance(obj, Function):
results_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
profile_list.extend(
[obj.profile for item in obj.maker.fgraph.outputs])
order = obj.maker.fgraph.toposort()
elif isinstance(obj, gof.FunctionGraph):
results_to_print.extend(obj.outputs)
......@@ -119,7 +120,7 @@ def debugprint(obj, depth=-1, print_type=False,
isinstance(r.owner.op, theano.scan_module.scan_op.Scan)):
scan_ops.append(r)
if p != None:
if p is not None:
print >> file, """
Timing Info
-----------
......@@ -264,6 +265,7 @@ class PrinterState(gof.utils.scratchpad):
props = {}
return PrinterState(self, **dict(props, **more_props))
class OperatorPrinter:
def __init__(self, operator, precedence, assoc='left'):
......@@ -279,13 +281,13 @@ class OperatorPrinter:
raise TypeError("operator %s cannot represent a variable that is "
"not the result of an operation" % self.operator)
## Precedence seems to be buggy, see #249
## So, in doubt, we parenthesize everything.
#outer_precedence = getattr(pstate, 'precedence', -999999)
#outer_assoc = getattr(pstate, 'assoc', 'none')
#if outer_precedence > self.precedence:
# Precedence seems to be buggy, see #249
# So, in doubt, we parenthesize everything.
# outer_precedence = getattr(pstate, 'precedence', -999999)
# outer_assoc = getattr(pstate, 'assoc', 'none')
# if outer_precedence > self.precedence:
# parenthesize = True
#else:
# else:
# parenthesize = False
parenthesize = True
......@@ -368,7 +370,6 @@ class MemberPrinter:
if node is None:
raise TypeError("function %s cannot represent a variable that is"
" not the result of an operation" % self.function)
names = self.names
idx = node.outputs.index(output)
name = self.names[idx]
input = node.inputs[0]
......@@ -463,7 +464,8 @@ class PPrinter:
inv_updates = dict((b, a) for (a, b) in updates.iteritems())
i = 1
for node in gof.graph.io_toposort(list(inputs) + updates.keys(),
list(outputs) + updates.values()):
list(outputs) +
updates.values()):
for output in node.outputs:
if output in inv_updates:
name = str(inv_updates[output])
......@@ -475,8 +477,8 @@ class PPrinter:
name = 'out[%i]' % outputs.index(output)
else:
name = output.name
#backport
#name = 'out[%i]' % outputs.index(output) if output.name
# backport
# name = 'out[%i]' % outputs.index(output) if output.name
# is None else output.name
current = output
try:
......@@ -639,7 +641,7 @@ def pydotprint(fct, outfile=None,
mode = fct.maker.mode
profile = getattr(fct, "profile", None)
if (not isinstance(mode, ProfileMode)
or not fct in mode.profile_stats):
or fct not in mode.profile_stats):
mode = None
outputs = fct.maker.fgraph.outputs
topo = fct.maker.fgraph.toposort()
......@@ -860,7 +862,8 @@ def pydotprint(fct, outfile=None,
g.add_edge(pd.Edge(varstr, astr, label=label, **param))
else:
# no name, so we don't make a var ellipse
g.add_edge(pd.Edge(apply_name(var.owner), astr, label=label, **param))
g.add_edge(pd.Edge(apply_name(var.owner), astr,
label=label, **param))
for id, var in enumerate(node.outputs):
varstr = var_name(var)
......@@ -952,8 +955,9 @@ def pydotprint_variables(vars,
try:
import pydot as pd
except ImportError:
print ("Failed to import pydot. You must install pydot for "
str = ("Failed to import pydot. You must install pydot for " +
"`pydotprint_variables` to work.")
print str
return
g = pd.Dot()
my_list = {}
......@@ -1192,9 +1196,11 @@ def min_informative_str(obj, indent_level=0,
elif hasattr(obj, 'owner') and obj.owner is not None:
name = str(obj.owner.op)
for ipt in obj.owner.inputs:
name += '\n' + min_informative_str(ipt,
name += '\n'
name += min_informative_str(ipt,
indent_level=indent_level + 1,
_prev_obs=_prev_obs, _tag_generator=_tag_generator)
_prev_obs=_prev_obs,
_tag_generator=_tag_generator)
else:
name = str(obj)
......@@ -1217,7 +1223,8 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
try:
import hashlib
except ImportError:
raise RuntimeError("Can't run var_descriptor because hashlib is not available.")
raise RuntimeError(
"Can't run var_descriptor because hashlib is not available.")
if _prev_obs is None:
_prev_obs = {}
......@@ -1239,13 +1246,15 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
# it can have different semantics depending on the strides
# of the ndarray
name = '<ndarray:'
name += 'strides=['+','.join(str(stride) for stride in obj.strides)+']'
name += ',digest='+hashlib.md5(obj).hexdigest()+'>'
name += 'strides=[' + ','.join(str(stride)
for stride in obj.strides) + ']'
name += ',digest=' + hashlib.md5(obj).hexdigest() + '>'
elif hasattr(obj, 'owner') and obj.owner is not None:
name = str(obj.owner.op) + '('
name += ','.join(var_descriptor(ipt,
_prev_obs=_prev_obs, _tag_generator=_tag_generator) for ipt
in obj.owner.inputs)
_prev_obs=_prev_obs,
_tag_generator=_tag_generator)
for ipt in obj.owner.inputs)
name += ')'
elif hasattr(obj, 'name') and obj.name is not None:
# Only print the name if there is no owner.
......@@ -1271,7 +1280,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
def position_independent_str(obj):
if isinstance(obj, theano.gof.graph.Variable):
rval = 'theano_var'
rval += '{type='+str(obj.type)+'}'
rval += '{type=' + str(obj.type) + '}'
else:
raise NotImplementedError()
......@@ -1288,13 +1297,15 @@ def hex_digest(x):
try:
import hashlib
except ImportError:
raise RuntimeError("Can't run hex_digest because hashlib is not available.")
raise RuntimeError("Can't run hex_digest"
"because hashlib is not available.")
assert isinstance(x, np.ndarray)
rval = hashlib.md5(x.tostring()).hexdigest()
# hex digest must be annotated with strides to avoid collisions
# because the buffer interface only exposes the raw data, not
# any info about the semantics of how that data should be arranged
# into a tensor
rval = rval + '|strides=[' + ','.join(str(stride) for stride in x.strides) + ']'
rval = rval + '|strides=[' + ','.join(str(stride)
for stride in x.strides) + ']'
rval = rval + '|shape=[' + ','.join(str(s) for s in x.shape) + ']'
return rval
......@@ -18,7 +18,6 @@ except ImportError:
whitelist_flake8 = [
"updates.py",
"printing.py",
"__init__.py",
"configparser.py",
"ifelse.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论