提交 66dc1578 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make printing.py pass test_flake8

上级 837f6ad6
...@@ -8,11 +8,9 @@ import logging ...@@ -8,11 +8,9 @@ import logging
import os import os
import sys import sys
import warnings import warnings
# Not available on all platforms import hashlib
hashlib = None
import numpy import numpy as np
np = numpy
try: try:
import pydot as pd import pydot as pd
...@@ -106,7 +104,7 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -106,7 +104,7 @@ def debugprint(obj, depth=-1, print_type=False,
results_to_print.extend(obj.outputs) results_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs]) profile_list.extend([None for item in obj.outputs])
order = obj.toposort() order = obj.toposort()
elif isinstance(obj, (int, long, float, numpy.ndarray)): elif isinstance(obj, (int, long, float, np.ndarray)):
print obj print obj
elif isinstance(obj, (theano.In, theano.Out)): elif isinstance(obj, (theano.In, theano.Out)):
results_to_print.append(obj.variable) results_to_print.append(obj.variable)
...@@ -296,8 +294,8 @@ class OperatorPrinter: ...@@ -296,8 +294,8 @@ class OperatorPrinter:
input_strings = [] input_strings = []
max_i = len(node.inputs) - 1 max_i = len(node.inputs) - 1
for i, input in enumerate(node.inputs): for i, input in enumerate(node.inputs):
if (self.assoc == 'left' and i != 0 or self.assoc == 'right' if (self.assoc == 'left' and i != 0 or self.assoc == 'right' and
and i != max_i): i != max_i):
s = pprinter.process(input, pstate.clone( s = pprinter.process(input, pstate.clone(
precedence=self.precedence + 1e-6)) precedence=self.precedence + 1e-6))
else: else:
...@@ -334,8 +332,9 @@ class PatternPrinter: ...@@ -334,8 +332,9 @@ class PatternPrinter:
pattern, precedences = self.patterns[idx] pattern, precedences = self.patterns[idx]
precedences += (1000,) * len(node.inputs) precedences += (1000,) * len(node.inputs)
pp_process = lambda input, precedence: pprinter.process( def pp_process(input, precedence):
input, pstate.clone(precedence=precedence)) return pprinter.process(input, pstate.clone(precedence=precedence))
d = dict((str(i), x) d = dict((str(i), x)
for i, x in enumerate(pp_process(input, precedence) for i, x in enumerate(pp_process(input, precedence)
for input, precedence in for input, precedence in
...@@ -416,15 +415,14 @@ class LeafPrinter: ...@@ -416,15 +415,14 @@ class LeafPrinter:
class PPrinter: class PPrinter:
def __init__(self): def __init__(self):
self.printers = [] self.printers = []
def assign(self, condition, printer): def assign(self, condition, printer):
if isinstance(condition, gof.Op): if isinstance(condition, gof.Op):
op = condition op = condition
condition = (lambda pstate, r: r.owner is not None condition = (lambda pstate, r: r.owner is not None and
and r.owner.op == op) r.owner.op == op)
self.printers.insert(0, (condition, printer)) self.printers.insert(0, (condition, printer))
def process(self, r, pstate=None): def process(self, r, pstate=None):
...@@ -460,14 +458,12 @@ class PPrinter: ...@@ -460,14 +458,12 @@ class PPrinter:
map(str, list(inputs) + updates.keys())))] map(str, list(inputs) + updates.keys())))]
else: else:
strings = [] strings = []
pprinter = self.clone_assign(lambda pstate, r: r.name is not None pprinter = self.clone_assign(lambda pstate, r: r.name is not None and
and r is not current, r is not current, LeafPrinter())
LeafPrinter())
inv_updates = dict((b, a) for (a, b) in updates.iteritems()) inv_updates = dict((b, a) for (a, b) in updates.iteritems())
i = 1 i = 1
for node in gof.graph.io_toposort(list(inputs) + updates.keys(), for node in gof.graph.io_toposort(list(inputs) + updates.keys(),
list(outputs) + list(outputs) + updates.values()):
updates.values()):
for output in node.outputs: for output in node.outputs:
if output in inv_updates: if output in inv_updates:
name = str(inv_updates[output]) name = str(inv_updates[output])
...@@ -532,8 +528,8 @@ else: ...@@ -532,8 +528,8 @@ else:
pprint = PPrinter() pprint = PPrinter()
pprint.assign(lambda pstate, r: True, DefaultPrinter()) pprint.assign(lambda pstate, r: True, DefaultPrinter())
pprint.assign(lambda pstate, r: hasattr(pstate, 'target') pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and
and pstate.target is not r and r.name is not None, pstate.target is not r and r.name is not None,
LeafPrinter()) LeafPrinter())
pp = pprint pp = pprint
...@@ -643,8 +639,8 @@ def pydotprint(fct, outfile=None, ...@@ -643,8 +639,8 @@ def pydotprint(fct, outfile=None,
if isinstance(fct, Function): if isinstance(fct, Function):
mode = fct.maker.mode mode = fct.maker.mode
profile = getattr(fct, "profile", None) profile = getattr(fct, "profile", None)
if (not isinstance(mode, ProfileMode) if (not isinstance(mode, ProfileMode) or
or fct not in mode.profile_stats): fct not in mode.profile_stats):
mode = None mode = None
outputs = fct.maker.fgraph.outputs outputs = fct.maker.fgraph.outputs
topo = fct.maker.fgraph.toposort() topo = fct.maker.fgraph.toposort()
...@@ -679,8 +675,8 @@ def pydotprint(fct, outfile=None, ...@@ -679,8 +675,8 @@ def pydotprint(fct, outfile=None,
c3 = pd.Cluster('Middle') c3 = pd.Cluster('Middle')
cond = None cond = None
for node in topo: for node in topo:
if (node.op.__class__.__name__ == 'IfElse' if (node.op.__class__.__name__ == 'IfElse' and
and node.op.name == cond_highlight): node.op.name == cond_highlight):
cond = node cond = node
if cond is None: if cond is None:
_logger.warn("pydotprint: cond_highlight is set but there is no" _logger.warn("pydotprint: cond_highlight is set but there is no"
...@@ -719,17 +715,17 @@ def pydotprint(fct, outfile=None, ...@@ -719,17 +715,17 @@ def pydotprint(fct, outfile=None,
else: else:
varstr = 'name=' + var.name + " " + str(var.type) varstr = 'name=' + var.name + " " + str(var.type)
elif isinstance(var, gof.Constant): elif isinstance(var, gof.Constant):
dstr = 'val=' + str(numpy.asarray(var.data)) dstr = 'val=' + str(np.asarray(var.data))
if '\n' in dstr: if '\n' in dstr:
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
varstr = '%s %s' % (dstr, str(var.type)) varstr = '%s %s' % (dstr, str(var.type))
elif (var in input_update elif (var in input_update and
and input_update[var].variable.name is not None): input_update[var].variable.name is not None):
if var_with_name_simple: if var_with_name_simple:
varstr = input_update[var].variable.name + " UPDATE" varstr = input_update[var].variable.name + " UPDATE"
else: else:
varstr = (input_update[var].variable.name + " UPDATE " varstr = (input_update[var].variable.name + " UPDATE " +
+ str(var.type)) str(var.type))
else: else:
# a var id is needed as otherwise var with the same type will be # a var id is needed as otherwise var with the same type will be
# merged in the graph. # merged in the graph.
...@@ -784,8 +780,8 @@ def pydotprint(fct, outfile=None, ...@@ -784,8 +780,8 @@ def pydotprint(fct, outfile=None,
if (applystr in all_strings) or with_ids: if (applystr in all_strings) or with_ids:
idx = ' id=' + str(topo.index(node)) idx = ' id=' + str(topo.index(node))
if len(applystr) + len(idx) > max_label_size: if len(applystr) + len(idx) > max_label_size:
applystr = (applystr[:max_label_size - 3 - len(idx)] + idx applystr = (applystr[:max_label_size - 3 - len(idx)] + idx +
+ '...') '...')
else: else:
applystr = applystr + idx applystr = applystr + idx
elif len(applystr) > max_label_size: elif len(applystr) > max_label_size:
...@@ -1220,15 +1216,6 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None): ...@@ -1220,15 +1216,6 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
how a variable is computed. Does not include any memory how a variable is computed. Does not include any memory
location dependent information such as the id of a node. location dependent information such as the id of a node.
""" """
global hashlib
if hashlib is None:
try:
import hashlib
except ImportError:
raise RuntimeError(
"Can't run var_descriptor because hashlib is not available.")
if _prev_obs is None: if _prev_obs is None:
_prev_obs = {} _prev_obs = {}
...@@ -1294,14 +1281,6 @@ def hex_digest(x): ...@@ -1294,14 +1281,6 @@ def hex_digest(x):
""" """
Returns a short, mostly hexadecimal hash of a numpy ndarray Returns a short, mostly hexadecimal hash of a numpy ndarray
""" """
global hashlib
if hashlib is None:
try:
import hashlib
except ImportError:
raise RuntimeError("Can't run hex_digest"
"because hashlib is not available.")
assert isinstance(x, np.ndarray) assert isinstance(x, np.ndarray)
rval = hashlib.md5(x.tostring()).hexdigest() rval = hashlib.md5(x.tostring()).hexdigest()
# hex digest must be annotated with strides to avoid collisions # hex digest must be annotated with strides to avoid collisions
......
...@@ -323,7 +323,6 @@ whitelist_flake8 = [ ...@@ -323,7 +323,6 @@ whitelist_flake8 = [
"gof/sandbox/equilibrium.py", "gof/sandbox/equilibrium.py",
"sandbox/cuda/opt_util.py", "sandbox/cuda/opt_util.py",
"gof/tests/test_utils.py", "gof/tests/test_utils.py",
"printing.py",
"raise_op.py", "raise_op.py",
"tests/test_flake8.py", "tests/test_flake8.py",
"misc/pkl_utils.py", "misc/pkl_utils.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论