提交 8c4e330d authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 of theano/gof/destroyhandler.py

上级 b129fb77
...@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings): ...@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
""" """
# These are lists of Variable instances # These are lists of Variable instances
inputs = fgraph.inputs
outputs = fgraph.outputs outputs = fgraph.outputs
# this is hard-coded reimplementation of functions from graph.py # this is hard-coded reimplementation of functions from graph.py
...@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings): ...@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings):
# (defaultdict runs faster than dict in the case where the key # (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython) # is not in the dictionary, at least in CPython)
iset = set(inputs)
# IG: I tried converting parent_counts to use an id for the key, # IG: I tried converting parent_counts to use an id for the key,
# so that the dict would do reference counting on its keys. # so that the dict would do reference counting on its keys.
# This caused a slowdown. # This caused a slowdown.
...@@ -236,9 +233,9 @@ def fast_inplace_check(inputs): ...@@ -236,9 +233,9 @@ def fast_inplace_check(inputs):
protected_inputs.extend(fgraph.outputs) protected_inputs.extend(fgraph.outputs)
inputs = [i for i in inputs if inputs = [i for i in inputs if
not isinstance(i, graph.Constant) not isinstance(i, graph.Constant) and
and not fgraph.destroyers(i) not fgraph.destroyers(i) and
and i not in protected_inputs] i not in protected_inputs]
return inputs return inputs
if 0: if 0:
...@@ -293,7 +290,7 @@ if 0: ...@@ -293,7 +290,7 @@ if 0:
TODO: WRITEME: what does this do besides the checks? TODO: WRITEME: what does this do besides the checks?
""" """
####### Do the checking ########### # Do the checking #
already_there = False already_there = False
if self.fgraph not in [None, fgraph]: if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve" raise Exception("A DestroyHandler instance can only serve"
...@@ -309,7 +306,7 @@ if 0: ...@@ -309,7 +306,7 @@ if 0:
"DestroyHandler feature is already present or in" "DestroyHandler feature is already present or in"
" conflict with another plugin.") " conflict with another plugin.")
####### end of checking ############ # end of checking #
def get_destroyers_of(r): def get_destroyers_of(r):
droot, impact, root_destroyer = self.refresh_droot_impact() droot, impact, root_destroyer = self.refresh_droot_impact()
...@@ -362,8 +359,8 @@ if 0: ...@@ -362,8 +359,8 @@ if 0:
"Multiple destroyers of %s" % input_root) "Multiple destroyers of %s" % input_root)
droot[input_root] = input_root droot[input_root] = input_root
root_destroyer[input_root] = app root_destroyer[input_root] = app
#input_impact = set([input_root]) # input_impact = set([input_root])
#add_impact(input_root, self.view_o, input_impact) # add_impact(input_root, self.view_o, input_impact)
input_impact = get_impact(input_root, self.view_o) input_impact = get_impact(input_root, self.view_o)
for v in input_impact: for v in input_impact:
assert v not in droot assert v not in droot
...@@ -390,7 +387,7 @@ if 0: ...@@ -390,7 +387,7 @@ if 0:
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import") # if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app) # self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
...@@ -421,7 +418,7 @@ if 0: ...@@ -421,7 +418,7 @@ if 0:
def on_prune(self, fgraph, app, reason): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import") # if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app) # self.debug_all_apps.remove(app)
# UPDATE self.clients # UPDATE self.clients
...@@ -458,7 +455,7 @@ if 0: ...@@ -458,7 +455,7 @@ if 0:
# considered 'outputs' of the graph. # considered 'outputs' of the graph.
pass pass
else: else:
#if app not in self.debug_all_apps: raise ProtocolError("change without import") # if app not in self.debug_all_apps: raise ProtocolError("change without import")
# UPDATE self.clients # UPDATE self.clients
self.clients[old_r][app] -= 1 self.clients[old_r][app] -= 1
...@@ -529,9 +526,10 @@ if 0: ...@@ -529,9 +526,10 @@ if 0:
droot, impact, __ignore = self.refresh_droot_impact() droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants # check for destruction of constants
illegal_destroy = [r for r in droot if illegal_destroy = [
getattr(r.tag, 'indestructible', False) or r for r in droot if
isinstance(r, graph.Constant)] getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy: if illegal_destroy:
# print 'destroying illegally' # print 'destroying illegally'
raise InconsistencyError( raise InconsistencyError(
...@@ -603,7 +601,7 @@ if 0: ...@@ -603,7 +601,7 @@ if 0:
if input in root_impact \ if input in root_impact \
and (i not in tolerated or input is not destroyed_variable): and (i not in tolerated or input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)" raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i)) % (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that # add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input # depend on destroyed_input
...@@ -621,7 +619,7 @@ if 0: ...@@ -621,7 +619,7 @@ if 0:
return rval return rval
class DestroyHandler(toolbox.Bookkeeper): class DestroyHandler(toolbox.Bookkeeper): # noqa
""" """
The DestroyHandler class detects when a graph is impossible to evaluate The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations. because of aliasing and destructive operations.
...@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper):
TODO: WRITEME: what does this do besides the checks? TODO: WRITEME: what does this do besides the checks?
""" """
####### Do the checking ########### # Do the checking #
already_there = False already_there = False
if self.fgraph is fgraph: if self.fgraph is fgraph:
already_there = True already_there = True
...@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper):
"DestroyHandler feature is already present" "DestroyHandler feature is already present"
" or in conflict with another plugin.") " or in conflict with another plugin.")
####### Annotate the FunctionGraph ############ # Annotate the FunctionGraph #
self.unpickle(fgraph) self.unpickle(fgraph)
fgraph.destroy_handler = self fgraph.destroy_handler = self
...@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper):
droot, impact, __ignore = self.refresh_droot_impact() droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants # check for destruction of constants
illegal_destroy = [r for r in droot if \ illegal_destroy = [r for r in droot if
getattr(r.tag, 'indestructible', False) or \ getattr(r.tag, 'indestructible', False) or
isinstance(r, graph.Constant)] isinstance(r, graph.Constant)]
if illegal_destroy: if illegal_destroy:
raise InconsistencyError("Attempting to destroy indestructible variables: %s" % raise InconsistencyError(
illegal_destroy) "Attempting to destroy indestructible variables: %s" %
illegal_destroy)
# add destroyed variable clients as computational dependencies # add destroyed variable clients as computational dependencies
for app in self.destroyers: for app in self.destroyers:
...@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper):
# CHECK FOR INPUT ALIASING # CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import # OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', []) tolerate_same = getattr(app.op,
'destroyhandler_tolerate_same', [])
assert isinstance(tolerate_same, list) assert isinstance(tolerate_same, list)
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', []) tolerate_aliased = getattr(
app.op, 'destroyhandler_tolerate_aliased', [])
assert isinstance(tolerate_aliased, list) assert isinstance(tolerate_aliased, list)
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
# print 'tolerated', tolerated # print 'tolerated', tolerated
# print 'ignored', ignored # print 'ignored', ignored
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
if i in ignored: if i in ignored:
continue continue
if input in root_impact \ if input in root_impact \
and (i not in tolerated or input is not destroyed_variable): and (i not in tolerated or
input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)" raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i)) % (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that # add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input # depend on destroyed_input
......
...@@ -240,7 +240,6 @@ whitelist_flake8 = [ ...@@ -240,7 +240,6 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py", "sparse/sandbox/sp2.py",
"sparse/sandbox/truedot.py", "sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py", "sparse/sandbox/sp.py",
"gof/destroyhandler.py",
"gof/unify.py", "gof/unify.py",
"gof/graph.py", "gof/graph.py",
"gof/__init__.py", "gof/__init__.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论