提交 cba9c812 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3130 from harlouci/flake8_gof

Flake8 gof
...@@ -177,7 +177,7 @@ def get_config_md5(): ...@@ -177,7 +177,7 @@ def get_config_md5():
""" """
all_opts = sorted([c for c in _config_var_list if c.in_c_key], all_opts = sorted([c for c in _config_var_list if c.in_c_key],
key=lambda cv: cv.fullname) key=lambda cv: cv.fullname)
return theano.gof.cc.hash_from_code('\n'.join( return theano.gof.utils.hash_from_code('\n'.join(
['%s = %s' % (cv.fullname, cv.__get__()) for cv in all_opts])) ['%s = %s' % (cv.fullname, cv.__get__()) for cv in all_opts]))
......
差异被折叠。
差异被折叠。
...@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings): ...@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
""" """
# These are lists of Variable instances # These are lists of Variable instances
inputs = fgraph.inputs
outputs = fgraph.outputs outputs = fgraph.outputs
# this is hard-coded reimplementation of functions from graph.py # this is hard-coded reimplementation of functions from graph.py
...@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings): ...@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings):
# (defaultdict runs faster than dict in the case where the key # (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython) # is not in the dictionary, at least in CPython)
iset = set(inputs)
# IG: I tried converting parent_counts to use an id for the key, # IG: I tried converting parent_counts to use an id for the key,
# so that the dict would do reference counting on its keys. # so that the dict would do reference counting on its keys.
# This caused a slowdown. # This caused a slowdown.
...@@ -236,9 +233,9 @@ def fast_inplace_check(inputs): ...@@ -236,9 +233,9 @@ def fast_inplace_check(inputs):
protected_inputs.extend(fgraph.outputs) protected_inputs.extend(fgraph.outputs)
inputs = [i for i in inputs if inputs = [i for i in inputs if
not isinstance(i, graph.Constant) not isinstance(i, graph.Constant) and
and not fgraph.destroyers(i) not fgraph.destroyers(i) and
and i not in protected_inputs] i not in protected_inputs]
return inputs return inputs
if 0: if 0:
...@@ -293,7 +290,7 @@ if 0: ...@@ -293,7 +290,7 @@ if 0:
TODO: WRITEME: what does this do besides the checks? TODO: WRITEME: what does this do besides the checks?
""" """
####### Do the checking ########### # Do the checking #
already_there = False already_there = False
if self.fgraph not in [None, fgraph]: if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve" raise Exception("A DestroyHandler instance can only serve"
...@@ -309,7 +306,7 @@ if 0: ...@@ -309,7 +306,7 @@ if 0:
"DestroyHandler feature is already present or in" "DestroyHandler feature is already present or in"
" conflict with another plugin.") " conflict with another plugin.")
####### end of checking ############ # end of checking #
def get_destroyers_of(r): def get_destroyers_of(r):
droot, impact, root_destroyer = self.refresh_droot_impact() droot, impact, root_destroyer = self.refresh_droot_impact()
...@@ -362,8 +359,8 @@ if 0: ...@@ -362,8 +359,8 @@ if 0:
"Multiple destroyers of %s" % input_root) "Multiple destroyers of %s" % input_root)
droot[input_root] = input_root droot[input_root] = input_root
root_destroyer[input_root] = app root_destroyer[input_root] = app
#input_impact = set([input_root]) # input_impact = set([input_root])
#add_impact(input_root, self.view_o, input_impact) # add_impact(input_root, self.view_o, input_impact)
input_impact = get_impact(input_root, self.view_o) input_impact = get_impact(input_root, self.view_o)
for v in input_impact: for v in input_impact:
assert v not in droot assert v not in droot
...@@ -390,7 +387,7 @@ if 0: ...@@ -390,7 +387,7 @@ if 0:
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import") # if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app) # self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
...@@ -421,7 +418,7 @@ if 0: ...@@ -421,7 +418,7 @@ if 0:
def on_prune(self, fgraph, app, reason): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import") # if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app) # self.debug_all_apps.remove(app)
# UPDATE self.clients # UPDATE self.clients
...@@ -458,7 +455,7 @@ if 0: ...@@ -458,7 +455,7 @@ if 0:
# considered 'outputs' of the graph. # considered 'outputs' of the graph.
pass pass
else: else:
#if app not in self.debug_all_apps: raise ProtocolError("change without import") # if app not in self.debug_all_apps: raise ProtocolError("change without import")
# UPDATE self.clients # UPDATE self.clients
self.clients[old_r][app] -= 1 self.clients[old_r][app] -= 1
...@@ -529,9 +526,10 @@ if 0: ...@@ -529,9 +526,10 @@ if 0:
droot, impact, __ignore = self.refresh_droot_impact() droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants # check for destruction of constants
illegal_destroy = [r for r in droot if illegal_destroy = [
getattr(r.tag, 'indestructible', False) or r for r in droot if
isinstance(r, graph.Constant)] getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy: if illegal_destroy:
# print 'destroying illegally' # print 'destroying illegally'
raise InconsistencyError( raise InconsistencyError(
...@@ -603,7 +601,7 @@ if 0: ...@@ -603,7 +601,7 @@ if 0:
if input in root_impact \ if input in root_impact \
and (i not in tolerated or input is not destroyed_variable): and (i not in tolerated or input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)" raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i)) % (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that # add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input # depend on destroyed_input
...@@ -621,7 +619,7 @@ if 0: ...@@ -621,7 +619,7 @@ if 0:
return rval return rval
class DestroyHandler(toolbox.Bookkeeper): class DestroyHandler(toolbox.Bookkeeper): # noqa
""" """
The DestroyHandler class detects when a graph is impossible to evaluate The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations. because of aliasing and destructive operations.
...@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper):
TODO: WRITEME: what does this do besides the checks? TODO: WRITEME: what does this do besides the checks?
""" """
####### Do the checking ########### # Do the checking #
already_there = False already_there = False
if self.fgraph is fgraph: if self.fgraph is fgraph:
already_there = True already_there = True
...@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper):
"DestroyHandler feature is already present" "DestroyHandler feature is already present"
" or in conflict with another plugin.") " or in conflict with another plugin.")
####### Annotate the FunctionGraph ############ # Annotate the FunctionGraph #
self.unpickle(fgraph) self.unpickle(fgraph)
fgraph.destroy_handler = self fgraph.destroy_handler = self
...@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper):
droot, impact, __ignore = self.refresh_droot_impact() droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants # check for destruction of constants
illegal_destroy = [r for r in droot if \ illegal_destroy = [r for r in droot if
getattr(r.tag, 'indestructible', False) or \ getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)] isinstance(r, graph.Constant)]
if illegal_destroy: if illegal_destroy:
raise InconsistencyError("Attempting to destroy indestructible variables: %s" % raise InconsistencyError(
illegal_destroy) "Attempting to destroy indestructible variables: %s" %
illegal_destroy)
# add destroyed variable clients as computational dependencies # add destroyed variable clients as computational dependencies
for app in self.destroyers: for app in self.destroyers:
...@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper):
# CHECK FOR INPUT ALIASING # CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import # OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', []) tolerate_same = getattr(app.op,
'destroyhandler_tolerate_same', [])
assert isinstance(tolerate_same, list) assert isinstance(tolerate_same, list)
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', []) tolerate_aliased = getattr(
app.op, 'destroyhandler_tolerate_aliased', [])
assert isinstance(tolerate_aliased, list) assert isinstance(tolerate_aliased, list)
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
# print 'tolerated', tolerated # print 'tolerated', tolerated
# print 'ignored', ignored # print 'ignored', ignored
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
if i in ignored: if i in ignored:
continue continue
if input in root_impact \ if input in root_impact \
and (i not in tolerated or input is not destroyed_variable): and (i not in tolerated or
input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)" raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i)) % (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that # add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input # depend on destroyed_input
......
...@@ -13,7 +13,6 @@ from theano.gof import graph ...@@ -13,7 +13,6 @@ from theano.gof import graph
from theano.gof import utils from theano.gof import utils
from theano.gof import toolbox from theano.gof import toolbox
from theano import config from theano import config
import warnings
from theano.compat import OrderedDict from theano.compat import OrderedDict
from six import iteritems, itervalues from six import iteritems, itervalues
...@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet ...@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet
NullType = None NullType = None
class CachedConstantError(Exception): class CachedConstantError(Exception):
"""An exception thrown when we put in a FunctionGraph a Constant """An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this that is cached. This should not happen as the user can reuse this
...@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2): ...@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2):
self.variable_locks = {} self.variable_locks = {}
self.profile = None self.profile = None
### Setup a Variable ### # Setup a Variable #
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this fgraph # sets up r so it belongs to this fgraph
if getattr(r, 'cached', False): if getattr(r, 'cached', False):
...@@ -152,12 +152,12 @@ class FunctionGraph(utils.object2): ...@@ -152,12 +152,12 @@ class FunctionGraph(utils.object2):
" graph that has a cached constant. This should not happen." " graph that has a cached constant. This should not happen."
" Clone the graph before building the FunctionGraph.") " Clone the graph before building the FunctionGraph.")
if (hasattr(r, 'fgraph') and if (hasattr(r, 'fgraph') and
r.fgraph is not None and r.fgraph is not None and
r.fgraph is not self): r.fgraph is not self):
raise Exception("%s is already owned by another fgraph" % r) raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self r.fgraph = self
r.clients = [] r.clients = []
#self.execute_callbacks('on_setup_variable', r) # self.execute_callbacks('on_setup_variable', r)
def __setup_node__(self, node): def __setup_node__(self, node):
# sets up node so it belongs to this fgraph # sets up node so it belongs to this fgraph
...@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2): ...@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2):
str(node.op), str(node.op.destroy_map))) str(node.op), str(node.op.destroy_map)))
node.fgraph = self node.fgraph = self
node.deps = {} node.deps = {}
#self.execute_callbacks('on_setup_node', node) # self.execute_callbacks('on_setup_node', node)
def disown(self): def disown(self):
""" WRITEME """ WRITEME
...@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2): ...@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2):
self.inputs = None self.inputs = None
self.outputs = None self.outputs = None
### clients ### # clients #
def clients(self, r): def clients(self, r):
""" """
Set of all the (node, i) pairs such that node.inputs[i] is r. Set of all the (node, i) pairs such that node.inputs[i] is r.
...@@ -221,9 +221,9 @@ class FunctionGraph(utils.object2): ...@@ -221,9 +221,9 @@ class FunctionGraph(utils.object2):
if set(r.clients).intersection(set(new_clients)): if set(r.clients).intersection(set(new_clients)):
print('ERROR: clients intersect!', file=sys.stderr) print('ERROR: clients intersect!', file=sys.stderr)
print(' RCLIENTS of', r, [(n, i, type(n), id(n)) print(' RCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in r.clients], file=sys.stderr) for n, i in r.clients], file=sys.stderr)
print(' NCLIENTS of', r, [(n, i, type(n), id(n)) print(' NCLIENTS of', r, [(n, i, type(n), id(n))
for n, i in new_clients], file=sys.stderr) for n, i in new_clients], file=sys.stderr)
assert not set(r.clients).intersection(set(new_clients)) assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients r.clients += new_clients
...@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2): ...@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2):
return True return True
return False return False
### import ### # import #
def __import_r__(self, variable, reason): def __import_r__(self, variable, reason):
global NullType global NullType
if NullType is None: if NullType is None:
...@@ -279,9 +279,8 @@ class FunctionGraph(utils.object2): ...@@ -279,9 +279,8 @@ class FunctionGraph(utils.object2):
if hasattr(r, 'fgraph') and r.fgraph is not self: if hasattr(r, 'fgraph') and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r) raise Exception("%s is already owned by another fgraph" % r)
if (r.owner is None and if (r.owner is None and
not isinstance(r, graph.Constant) and not isinstance(r, graph.Constant) and
r not in self.inputs): r not in self.inputs):
# Verbose error message # Verbose error message
# Show a complete chain of variables from the missing input to an output # Show a complete chain of variables from the missing input to an output
if config.exception_verbosity == 'high': if config.exception_verbosity == 'high':
...@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2): ...@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2):
assert node.fgraph is self assert node.fgraph is self
self.execute_callbacks('on_import', node, reason) self.execute_callbacks('on_import', node, reason)
### prune ### # prune #
def __prune_r__(self, variable, reason=None): def __prune_r__(self, variable, reason=None):
"""Should be called for variable that aren't used anymore: """Should be called for variable that aren't used anymore:
len(var.clients) == 0 len(var.clients) == 0
...@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2): ...@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2):
self.__remove_clients__(input, [(apply_node, i)], reason=reason) self.__remove_clients__(input, [(apply_node, i)], reason=reason)
# self.__prune_r__(apply_node.inputs) # self.__prune_r__(apply_node.inputs)
### change input ### # change input #
def change_input(self, node, i, new_r, reason=None): def change_input(self, node, i, new_r, reason=None):
"""WRITEME """WRITEME
Changes node.inputs[i] to new_r. Changes node.inputs[i] to new_r.
...@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2): ...@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2):
if prune: if prune:
self.__prune_r__(r, reason=reason) self.__prune_r__(r, reason=reason)
### replace ### # replace #
def replace(self, r, new_r, reason=None, verbose=None): def replace(self, r, new_r, reason=None, verbose=None):
""" WRITEME """ WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph. This is the main interface to manipulate the subgraph in FunctionGraph.
...@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2): ...@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2):
if detach is not None: if detach is not None:
detach(self) detach(self)
### callback utils ### # callback utils #
def execute_callbacks(self, name, *args, **kwargs): def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME """WRITEME
Calls Calls
...@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2): ...@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2):
d[feature] = fn(*args) d[feature] = fn(*args)
return d return d
### misc ### # misc #
def toposort(self): def toposort(self):
"""WRITEME """WRITEME
Returns an ordering of the graph's Apply nodes such that: Returns an ordering of the graph's Apply nodes such that:
...@@ -712,8 +711,8 @@ class FunctionGraph(utils.object2): ...@@ -712,8 +711,8 @@ class FunctionGraph(utils.object2):
missing, excess) missing, excess)
for variable in variables: for variable in variables:
if (variable.owner is None and if (variable.owner is None and
variable not in self.inputs and variable not in self.inputs and
not isinstance(variable, graph.Constant)): not isinstance(variable, graph.Constant)):
raise Exception("Undeclared input.", variable) raise Exception("Undeclared input.", variable)
if variable.fgraph is not self: if variable.fgraph is not self:
raise Exception("Variable should belong to the FunctionGraph.", raise Exception("Variable should belong to the FunctionGraph.",
...@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2): ...@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2):
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
### clone ### # clone #
def clone(self, check_integrity=True): def clone(self, check_integrity=True):
"""WRITEME""" """WRITEME"""
return self.clone_get_equiv(check_integrity)[0] return self.clone_get_equiv(check_integrity)[0]
......
...@@ -7,14 +7,14 @@ import traceback ...@@ -7,14 +7,14 @@ import traceback
import numpy import numpy
import theano import theano
from theano.compat import PY3, izip from theano.compat import izip
from six import reraise from six import reraise
from six.moves import StringIO from six.moves import StringIO
from theano.gof import utils from theano.gof import utils
from theano.gof import graph from theano.gof import graph
from theano.gof.type import Type from theano.gof.type import Type
from .utils import MethodNotDefined, undef from .utils import undef
__excepthook = sys.excepthook __excepthook = sys.excepthook
...@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
else: else:
detailed_err_msg += "\n" detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % ( detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % (
total_size, total_size/1024./1024/1024) total_size, total_size / 1024. / 1024 / 1024)
detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f BG\n" % ( detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f BG\n" % (
total_size_inputs, total_size_inputs/1024./1024/1024) total_size_inputs, total_size_inputs / 1024. / 1024 / 1024)
else: else:
hints.append( hints.append(
...@@ -326,7 +326,7 @@ class Linker(object): ...@@ -326,7 +326,7 @@ class Linker(object):
raise utils.MethodNotDefined("make_thunk", type(self), raise utils.MethodNotDefined("make_thunk", type(self),
self.__class__.__name__) self.__class__.__name__)
## DELETEME ## # DELETEME #
def make_function(self, unpack_single=True, **kwargs): def make_function(self, unpack_single=True, **kwargs):
""" """
Returns a function that takes values corresponding to the inputs of the Returns a function that takes values corresponding to the inputs of the
...@@ -350,8 +350,8 @@ class Linker(object): ...@@ -350,8 +350,8 @@ class Linker(object):
def execute(*args): def execute(*args):
def e_arity(takes, got): def e_arity(takes, got):
return 'Function call takes exactly %i %s (%i given)' \ return 'Function call takes exactly %i %s (%i given)' % (
% (takes, ['argument', 'arguments'][takes > 1], got) takes, ['argument', 'arguments'][takes > 1], got)
if (len(args) != len(inputs)): if (len(args) != len(inputs)):
raise TypeError(e_arity(len(inputs), len(args))) raise TypeError(e_arity(len(inputs), len(args)))
for arg, variable in izip(args, inputs): for arg, variable in izip(args, inputs):
...@@ -394,7 +394,7 @@ class Container(object): ...@@ -394,7 +394,7 @@ class Container(object):
""" """
if not isinstance(storage, list) or not len(storage) >= 1: if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one") raise TypeError("storage must be a list of length at least one")
#self.r = r # self.r = r
if isinstance(r, Type): if isinstance(r, Type):
self.type = r self.type = r
else: else:
...@@ -454,12 +454,11 @@ class Container(object): ...@@ -454,12 +454,11 @@ class Container(object):
deepcopy(self.strict, memo=memo), deepcopy(self.strict, memo=memo),
deepcopy(self.allow_downcast, memo=memo), deepcopy(self.allow_downcast, memo=memo),
deepcopy(self.name, memo=memo), deepcopy(self.name, memo=memo),
) )
# Work around NumPy deepcopy of ndarray with 0 dimention that # Work around NumPy deepcopy of ndarray with 0 dimention that
# don't return an ndarray. # don't return an ndarray.
if (r.storage[0] is not None and if (r.storage[0] is not None and
not self.type.is_valid_value(r.storage[0])): not self.type.is_valid_value(r.storage[0])):
assert not data_was_in_memo assert not data_was_in_memo
assert self.type.is_valid_value(self.storage[0]) assert self.type.is_valid_value(self.storage[0])
# This should also work for read only container. # This should also work for read only container.
...@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker): ...@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker):
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)(allow_gc=self.allow_gc).accept(fgraph, no_recycling) return type(self)(allow_gc=self.allow_gc).accept(fgraph, no_recycling)
#raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.") # raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
self.fgraph = fgraph self.fgraph = fgraph
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
...@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker): ...@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker):
for node in order: for node in order:
if self.allow_gc: if self.allow_gc:
post_thunk_old_storage.append([storage_map[input] post_thunk_old_storage.append(
for input in node.inputs [storage_map[input]
if (input in computed) and (input not in fgraph.outputs) and node == last_user[input]]) for input in node.inputs
if (input in computed) and (
input not in fgraph.outputs) and (
node == last_user[input])])
if no_recycling is True: if no_recycling is True:
# True seems like some special code for *everything*?? -JB # True seems like some special code for *everything*?? -JB
...@@ -855,7 +857,7 @@ class WrapLinker(Linker): ...@@ -855,7 +857,7 @@ class WrapLinker(Linker):
make_all += [l.make_all(**kwargs) for l in self.linkers[1:]] make_all += [l.make_all(**kwargs) for l in self.linkers[1:]]
fns, input_lists, output_lists, thunk_lists, order_lists \ fns, input_lists, output_lists, thunk_lists, order_lists \
= zip(*make_all) = zip(*make_all)
order_list0 = order_lists[0] order_list0 = order_lists[0]
for order_list in order_lists[1:]: for order_list in order_lists[1:]:
......
差异被折叠。
...@@ -3,9 +3,11 @@ import linecache ...@@ -3,9 +3,11 @@ import linecache
import traceback import traceback
import sys import sys
import numpy
from six import iteritems from six import iteritems
from theano import config from theano import config
from theano.compat import PY3
def simple_extract_stack(f=None, limit=None): def simple_extract_stack(f=None, limit=None):
...@@ -435,3 +437,31 @@ def remove(predicate, coll): ...@@ -435,3 +437,31 @@ def remove(predicate, coll):
[1, 3] [1, 3]
""" """
return [x for x in coll if not predicate(x)] return [x for x in coll if not predicate(x)]
if PY3:
import hashlib
def hash_from_code(msg):
# hashlib.md5() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
if isinstance(msg, str):
msg = msg.encode()
# Python 3 does not like module names that start with
# a digit.
return 'm' + hashlib.md5(msg).hexdigest()
else:
import hashlib
def hash_from_code(msg):
try:
return hashlib.md5(msg).hexdigest()
except TypeError:
assert isinstance(msg, numpy.ndarray)
return hashlib.md5(numpy.getbuffer(msg)).hexdigest()
def hash_from_file(file_path):
"""Return the MD5 hash of a file."""
return hash_from_code(open(file_path, 'rb').read())
...@@ -10,7 +10,7 @@ import numpy ...@@ -10,7 +10,7 @@ import numpy
from theano.compat import decode, decode_iter from theano.compat import decode, decode_iter
from theano.gof import local_bitwidth from theano.gof import local_bitwidth
from theano.gof.cc import hash_from_file from theano.gof.utils import hash_from_file
from theano.gof.cmodule import (std_libs, std_lib_dirs, from theano.gof.cmodule import (std_libs, std_lib_dirs,
std_include_dirs, dlimport, std_include_dirs, dlimport,
Compiler, Compiler,
......
from theano.gof.cc import hash_from_code from theano.gof.utils import hash_from_code
def hash_from_sparse(data): def hash_from_sparse(data):
......
...@@ -2,7 +2,7 @@ import numpy ...@@ -2,7 +2,7 @@ import numpy
import theano import theano
from theano.compat import izip from theano.compat import izip
from theano.gof.cc import hash_from_code from theano.gof.utils import hash_from_code
def hash_from_ndarray(data): def hash_from_ndarray(data):
......
...@@ -233,16 +233,10 @@ whitelist_flake8 = [ ...@@ -233,16 +233,10 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py", "sparse/sandbox/sp2.py",
"sparse/sandbox/truedot.py", "sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py", "sparse/sandbox/sp.py",
"gof/destroyhandler.py",
"gof/unify.py", "gof/unify.py",
"gof/graph.py", "gof/graph.py",
"gof/__init__.py", "gof/__init__.py",
"gof/cc.py",
"gof/opt.py",
"gof/link.py",
"gof/fg.py",
"gof/op.py", "gof/op.py",
"gof/cmodule.py",
"gof/tests/test_cmodule.py", "gof/tests/test_cmodule.py",
"gof/tests/test_destroyhandler.py", "gof/tests/test_destroyhandler.py",
"gof/tests/test_opt.py", "gof/tests/test_opt.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论