提交 cba9c812 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3130 from harlouci/flake8_gof

Flake8 gof
......@@ -177,7 +177,7 @@ def get_config_md5():
"""
all_opts = sorted([c for c in _config_var_list if c.in_c_key],
key=lambda cv: cv.fullname)
return theano.gof.cc.hash_from_code('\n'.join(
return theano.gof.utils.hash_from_code('\n'.join(
['%s = %s' % (cv.fullname, cv.__get__()) for cv in all_opts]))
......
差异被折叠。
差异被折叠。
......@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
"""
# These are lists of Variable instances
inputs = fgraph.inputs
outputs = fgraph.outputs
# this is hard-coded reimplementation of functions from graph.py
......@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings):
# (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython)
iset = set(inputs)
# IG: I tried converting parent_counts to use an id for the key,
# so that the dict would do reference counting on its keys.
# This caused a slowdown.
......@@ -236,9 +233,9 @@ def fast_inplace_check(inputs):
protected_inputs.extend(fgraph.outputs)
inputs = [i for i in inputs if
not isinstance(i, graph.Constant)
and not fgraph.destroyers(i)
and i not in protected_inputs]
not isinstance(i, graph.Constant) and
not fgraph.destroyers(i) and
i not in protected_inputs]
return inputs
if 0:
......@@ -293,7 +290,7 @@ if 0:
TODO: WRITEME: what does this do besides the checks?
"""
####### Do the checking ###########
# Do the checking #
already_there = False
if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve"
......@@ -309,7 +306,7 @@ if 0:
"DestroyHandler feature is already present or in"
" conflict with another plugin.")
####### end of checking ############
# end of checking #
def get_destroyers_of(r):
droot, impact, root_destroyer = self.refresh_droot_impact()
......@@ -362,8 +359,8 @@ if 0:
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
#input_impact = set([input_root])
#add_impact(input_root, self.view_o, input_impact)
# input_impact = set([input_root])
# add_impact(input_root, self.view_o, input_impact)
input_impact = get_impact(input_root, self.view_o)
for v in input_impact:
assert v not in droot
......@@ -390,7 +387,7 @@ if 0:
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
# if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
......@@ -421,7 +418,7 @@ if 0:
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app)
# UPDATE self.clients
......@@ -458,7 +455,7 @@ if 0:
# considered 'outputs' of the graph.
pass
else:
#if app not in self.debug_all_apps: raise ProtocolError("change without import")
# if app not in self.debug_all_apps: raise ProtocolError("change without import")
# UPDATE self.clients
self.clients[old_r][app] -= 1
......@@ -529,9 +526,10 @@ if 0:
droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants
illegal_destroy = [r for r in droot if
getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
illegal_destroy = [
r for r in droot if
getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy:
# print 'destroying illegally'
raise InconsistencyError(
......@@ -603,7 +601,7 @@ if 0:
if input in root_impact \
and (i not in tolerated or input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i))
% (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
......@@ -621,7 +619,7 @@ if 0:
return rval
class DestroyHandler(toolbox.Bookkeeper):
class DestroyHandler(toolbox.Bookkeeper): # noqa
"""
The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations.
......@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper):
TODO: WRITEME: what does this do besides the checks?
"""
####### Do the checking ###########
# Do the checking #
already_there = False
if self.fgraph is fgraph:
already_there = True
......@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper):
"DestroyHandler feature is already present"
" or in conflict with another plugin.")
####### Annotate the FunctionGraph ############
# Annotate the FunctionGraph #
self.unpickle(fgraph)
fgraph.destroy_handler = self
......@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper):
droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants
illegal_destroy = [r for r in droot if \
getattr(r.tag, 'indestructible', False) or \
isinstance(r, graph.Constant)]
illegal_destroy = [r for r in droot if
getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy:
raise InconsistencyError("Attempting to destroy indestructible variables: %s" %
illegal_destroy)
raise InconsistencyError(
"Attempting to destroy indestructible variables: %s" %
illegal_destroy)
# add destroyed variable clients as computational dependencies
for app in self.destroyers:
......@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper):
# CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerate_same = getattr(app.op,
'destroyhandler_tolerate_same', [])
assert isinstance(tolerate_same, list)
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx)
if idx0 == destroyed_idx)
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
tolerate_aliased = getattr(
app.op, 'destroyhandler_tolerate_aliased', [])
assert isinstance(tolerate_aliased, list)
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx)
if idx0 == destroyed_idx)
# print 'tolerated', tolerated
# print 'ignored', ignored
for i, input in enumerate(app.inputs):
if i in ignored:
continue
if input in root_impact \
and (i not in tolerated or input is not destroyed_variable):
and (i not in tolerated or
input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i))
% (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
......
......@@ -13,7 +13,6 @@ from theano.gof import graph
from theano.gof import utils
from theano.gof import toolbox
from theano import config
import warnings
from theano.compat import OrderedDict
from six import iteritems, itervalues
......@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet
NullType = None
class CachedConstantError(Exception):
"""An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
......@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2):
self.variable_locks = {}
self.profile = None
### Setup a Variable ###
# Setup a Variable #
def __setup_r__(self, r):
# sets up r so it belongs to this fgraph
if getattr(r, 'cached', False):
......@@ -152,12 +152,12 @@ class FunctionGraph(utils.object2):
" graph that has a cached constant. This should not happen."
" Clone the graph before building the FunctionGraph.")
if (hasattr(r, 'fgraph') and
r.fgraph is not None and
r.fgraph is not self):
r.fgraph is not None and
r.fgraph is not self):
raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self
r.clients = []
#self.execute_callbacks('on_setup_variable', r)
# self.execute_callbacks('on_setup_variable', r)
def __setup_node__(self, node):
# sets up node so it belongs to this fgraph
......@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2):
str(node.op), str(node.op.destroy_map)))
node.fgraph = self
node.deps = {}
#self.execute_callbacks('on_setup_node', node)
# self.execute_callbacks('on_setup_node', node)
def disown(self):
""" WRITEME
......@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2):
self.inputs = None
self.outputs = None
### clients ###
# clients #
def clients(self, r):
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
......@@ -221,9 +221,9 @@ class FunctionGraph(utils.object2):
if set(r.clients).intersection(set(new_clients)):
print('ERROR: clients intersect!', file=sys.stderr)
print(' RCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in r.clients], file=sys.stderr)
for n, i in r.clients], file=sys.stderr)
print(' NCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in new_clients], file=sys.stderr)
for n, i in new_clients], file=sys.stderr)
assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients
......@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2):
return True
return False
### import ###
# import #
def __import_r__(self, variable, reason):
global NullType
if NullType is None:
......@@ -279,9 +279,8 @@ class FunctionGraph(utils.object2):
if hasattr(r, 'fgraph') and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r)
if (r.owner is None and
not isinstance(r, graph.Constant) and
r not in self.inputs):
not isinstance(r, graph.Constant) and
r not in self.inputs):
# Verbose error message
# Show a complete chain of variables from the missing input to an output
if config.exception_verbosity == 'high':
......@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2):
assert node.fgraph is self
self.execute_callbacks('on_import', node, reason)
### prune ###
# prune #
def __prune_r__(self, variable, reason=None):
"""Should be called for variable that aren't used anymore:
len(var.clients) == 0
......@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2):
self.__remove_clients__(input, [(apply_node, i)], reason=reason)
# self.__prune_r__(apply_node.inputs)
### change input ###
# change input #
def change_input(self, node, i, new_r, reason=None):
"""WRITEME
Changes node.inputs[i] to new_r.
......@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2):
if prune:
self.__prune_r__(r, reason=reason)
### replace ###
# replace #
def replace(self, r, new_r, reason=None, verbose=None):
""" WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
......@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2):
if detach is not None:
detach(self)
### callback utils ###
# callback utils #
def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME
Calls
......@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2):
d[feature] = fn(*args)
return d
### misc ###
# misc #
def toposort(self):
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
......@@ -712,8 +711,8 @@ class FunctionGraph(utils.object2):
missing, excess)
for variable in variables:
if (variable.owner is None and
variable not in self.inputs and
not isinstance(variable, graph.Constant)):
variable not in self.inputs and
not isinstance(variable, graph.Constant)):
raise Exception("Undeclared input.", variable)
if variable.fgraph is not self:
raise Exception("Variable should belong to the FunctionGraph.",
......@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2):
def __repr__(self):
return self.__str__()
### clone ###
# clone #
def clone(self, check_integrity=True):
"""WRITEME"""
return self.clone_get_equiv(check_integrity)[0]
......
......@@ -7,14 +7,14 @@ import traceback
import numpy
import theano
from theano.compat import PY3, izip
from theano.compat import izip
from six import reraise
from six.moves import StringIO
from theano.gof import utils
from theano.gof import graph
from theano.gof.type import Type
from .utils import MethodNotDefined, undef
from .utils import undef
__excepthook = sys.excepthook
......@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
else:
detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % (
total_size, total_size/1024./1024/1024)
total_size, total_size / 1024. / 1024 / 1024)
detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f BG\n" % (
total_size_inputs, total_size_inputs/1024./1024/1024)
total_size_inputs, total_size_inputs / 1024. / 1024 / 1024)
else:
hints.append(
......@@ -326,7 +326,7 @@ class Linker(object):
raise utils.MethodNotDefined("make_thunk", type(self),
self.__class__.__name__)
## DELETEME ##
# DELETEME #
def make_function(self, unpack_single=True, **kwargs):
"""
Returns a function that takes values corresponding to the inputs of the
......@@ -350,8 +350,8 @@ class Linker(object):
def execute(*args):
def e_arity(takes, got):
return 'Function call takes exactly %i %s (%i given)' \
% (takes, ['argument', 'arguments'][takes > 1], got)
return 'Function call takes exactly %i %s (%i given)' % (
takes, ['argument', 'arguments'][takes > 1], got)
if (len(args) != len(inputs)):
raise TypeError(e_arity(len(inputs), len(args)))
for arg, variable in izip(args, inputs):
......@@ -394,7 +394,7 @@ class Container(object):
"""
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
#self.r = r
# self.r = r
if isinstance(r, Type):
self.type = r
else:
......@@ -454,12 +454,11 @@ class Container(object):
deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo),
)
)
# Work around NumPy deepcopy of ndarray with 0 dimention that
# don't return an ndarray.
if (r.storage[0] is not None and
not self.type.is_valid_value(r.storage[0])):
not self.type.is_valid_value(r.storage[0])):
assert not data_was_in_memo
assert self.type.is_valid_value(self.storage[0])
# This should also work for read only container.
......@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker):
no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)(allow_gc=self.allow_gc).accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
# raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
self.fgraph = fgraph
self.no_recycling = no_recycling
return self
......@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker):
for node in order:
if self.allow_gc:
post_thunk_old_storage.append([storage_map[input]
for input in node.inputs
if (input in computed) and (input not in fgraph.outputs) and node == last_user[input]])
post_thunk_old_storage.append(
[storage_map[input]
for input in node.inputs
if (input in computed) and (
input not in fgraph.outputs) and (
node == last_user[input])])
if no_recycling is True:
# True seems like some special code for *everything*?? -JB
......@@ -855,7 +857,7 @@ class WrapLinker(Linker):
make_all += [l.make_all(**kwargs) for l in self.linkers[1:]]
fns, input_lists, output_lists, thunk_lists, order_lists \
= zip(*make_all)
= zip(*make_all)
order_list0 = order_lists[0]
for order_list in order_lists[1:]:
......
差异被折叠。
......@@ -3,9 +3,11 @@ import linecache
import traceback
import sys
import numpy
from six import iteritems
from theano import config
from theano.compat import PY3
def simple_extract_stack(f=None, limit=None):
......@@ -435,3 +437,31 @@ def remove(predicate, coll):
[1, 3]
"""
return [x for x in coll if not predicate(x)]
if PY3:
import hashlib
def hash_from_code(msg):
# hashlib.md5() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
if isinstance(msg, str):
msg = msg.encode()
# Python 3 does not like module names that start with
# a digit.
return 'm' + hashlib.md5(msg).hexdigest()
else:
import hashlib
def hash_from_code(msg):
try:
return hashlib.md5(msg).hexdigest()
except TypeError:
assert isinstance(msg, numpy.ndarray)
return hashlib.md5(numpy.getbuffer(msg)).hexdigest()
def hash_from_file(file_path):
"""Return the MD5 hash of a file."""
return hash_from_code(open(file_path, 'rb').read())
......@@ -10,7 +10,7 @@ import numpy
from theano.compat import decode, decode_iter
from theano.gof import local_bitwidth
from theano.gof.cc import hash_from_file
from theano.gof.utils import hash_from_file
from theano.gof.cmodule import (std_libs, std_lib_dirs,
std_include_dirs, dlimport,
Compiler,
......
from theano.gof.cc import hash_from_code
from theano.gof.utils import hash_from_code
def hash_from_sparse(data):
......
......@@ -2,7 +2,7 @@ import numpy
import theano
from theano.compat import izip
from theano.gof.cc import hash_from_code
from theano.gof.utils import hash_from_code
def hash_from_ndarray(data):
......
......@@ -233,16 +233,10 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py",
"sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py",
"gof/destroyhandler.py",
"gof/unify.py",
"gof/graph.py",
"gof/__init__.py",
"gof/cc.py",
"gof/opt.py",
"gof/link.py",
"gof/fg.py",
"gof/op.py",
"gof/cmodule.py",
"gof/tests/test_cmodule.py",
"gof/tests/test_destroyhandler.py",
"gof/tests/test_opt.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论