提交 66dc1578 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make printing.py pass test_flake8

上级 837f6ad6
......@@ -8,11 +8,9 @@ import logging
import os
import sys
import warnings
# Not available on all platforms
hashlib = None
import hashlib
import numpy
np = numpy
import numpy as np
try:
import pydot as pd
......@@ -106,7 +104,7 @@ def debugprint(obj, depth=-1, print_type=False,
results_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs])
order = obj.toposort()
elif isinstance(obj, (int, long, float, numpy.ndarray)):
elif isinstance(obj, (int, long, float, np.ndarray)):
print obj
elif isinstance(obj, (theano.In, theano.Out)):
results_to_print.append(obj.variable)
......@@ -296,8 +294,8 @@ class OperatorPrinter:
input_strings = []
max_i = len(node.inputs) - 1
for i, input in enumerate(node.inputs):
if (self.assoc == 'left' and i != 0 or self.assoc == 'right'
and i != max_i):
if (self.assoc == 'left' and i != 0 or self.assoc == 'right' and
i != max_i):
s = pprinter.process(input, pstate.clone(
precedence=self.precedence + 1e-6))
else:
......@@ -334,8 +332,9 @@ class PatternPrinter:
pattern, precedences = self.patterns[idx]
precedences += (1000,) * len(node.inputs)
pp_process = lambda input, precedence: pprinter.process(
input, pstate.clone(precedence=precedence))
def pp_process(input, precedence):
return pprinter.process(input, pstate.clone(precedence=precedence))
d = dict((str(i), x)
for i, x in enumerate(pp_process(input, precedence)
for input, precedence in
......@@ -416,15 +415,14 @@ class LeafPrinter:
class PPrinter:
def __init__(self):
self.printers = []
def assign(self, condition, printer):
if isinstance(condition, gof.Op):
op = condition
condition = (lambda pstate, r: r.owner is not None
and r.owner.op == op)
condition = (lambda pstate, r: r.owner is not None and
r.owner.op == op)
self.printers.insert(0, (condition, printer))
def process(self, r, pstate=None):
......@@ -460,14 +458,12 @@ class PPrinter:
map(str, list(inputs) + updates.keys())))]
else:
strings = []
pprinter = self.clone_assign(lambda pstate, r: r.name is not None
and r is not current,
LeafPrinter())
pprinter = self.clone_assign(lambda pstate, r: r.name is not None and
r is not current, LeafPrinter())
inv_updates = dict((b, a) for (a, b) in updates.iteritems())
i = 1
for node in gof.graph.io_toposort(list(inputs) + updates.keys(),
list(outputs) +
updates.values()):
list(outputs) + updates.values()):
for output in node.outputs:
if output in inv_updates:
name = str(inv_updates[output])
......@@ -532,8 +528,8 @@ else:
pprint = PPrinter()
pprint.assign(lambda pstate, r: True, DefaultPrinter())
pprint.assign(lambda pstate, r: hasattr(pstate, 'target')
and pstate.target is not r and r.name is not None,
pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and
pstate.target is not r and r.name is not None,
LeafPrinter())
pp = pprint
......@@ -643,8 +639,8 @@ def pydotprint(fct, outfile=None,
if isinstance(fct, Function):
mode = fct.maker.mode
profile = getattr(fct, "profile", None)
if (not isinstance(mode, ProfileMode)
or fct not in mode.profile_stats):
if (not isinstance(mode, ProfileMode) or
fct not in mode.profile_stats):
mode = None
outputs = fct.maker.fgraph.outputs
topo = fct.maker.fgraph.toposort()
......@@ -679,8 +675,8 @@ def pydotprint(fct, outfile=None,
c3 = pd.Cluster('Middle')
cond = None
for node in topo:
if (node.op.__class__.__name__ == 'IfElse'
and node.op.name == cond_highlight):
if (node.op.__class__.__name__ == 'IfElse' and
node.op.name == cond_highlight):
cond = node
if cond is None:
_logger.warn("pydotprint: cond_highlight is set but there is no"
......@@ -719,17 +715,17 @@ def pydotprint(fct, outfile=None,
else:
varstr = 'name=' + var.name + " " + str(var.type)
elif isinstance(var, gof.Constant):
dstr = 'val=' + str(numpy.asarray(var.data))
dstr = 'val=' + str(np.asarray(var.data))
if '\n' in dstr:
dstr = dstr[:dstr.index('\n')]
varstr = '%s %s' % (dstr, str(var.type))
elif (var in input_update
and input_update[var].variable.name is not None):
elif (var in input_update and
input_update[var].variable.name is not None):
if var_with_name_simple:
varstr = input_update[var].variable.name + " UPDATE"
else:
varstr = (input_update[var].variable.name + " UPDATE "
+ str(var.type))
varstr = (input_update[var].variable.name + " UPDATE " +
str(var.type))
else:
# a var id is needed as otherwise var with the same type will be
# merged in the graph.
......@@ -784,8 +780,8 @@ def pydotprint(fct, outfile=None,
if (applystr in all_strings) or with_ids:
idx = ' id=' + str(topo.index(node))
if len(applystr) + len(idx) > max_label_size:
applystr = (applystr[:max_label_size - 3 - len(idx)] + idx
+ '...')
applystr = (applystr[:max_label_size - 3 - len(idx)] + idx +
'...')
else:
applystr = applystr + idx
elif len(applystr) > max_label_size:
......@@ -1220,15 +1216,6 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
how a variable is computed. Does not include any memory
location dependent information such as the id of a node.
"""
global hashlib
if hashlib is None:
try:
import hashlib
except ImportError:
raise RuntimeError(
"Can't run var_descriptor because hashlib is not available.")
if _prev_obs is None:
_prev_obs = {}
......@@ -1294,14 +1281,6 @@ def hex_digest(x):
"""
Returns a short, mostly hexadecimal hash of a numpy ndarray
"""
global hashlib
if hashlib is None:
try:
import hashlib
except ImportError:
raise RuntimeError("Can't run hex_digest"
"because hashlib is not available.")
assert isinstance(x, np.ndarray)
rval = hashlib.md5(x.tostring()).hexdigest()
# hex digest must be annotated with strides to avoid collisions
......
......@@ -323,7 +323,6 @@ whitelist_flake8 = [
"gof/sandbox/equilibrium.py",
"sandbox/cuda/opt_util.py",
"gof/tests/test_utils.py",
"printing.py",
"raise_op.py",
"tests/test_flake8.py",
"misc/pkl_utils.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论