提交 cbf36515 authored 作者: Frederic's avatar Frederic

pep8

上级 0d3dffac
...@@ -11,6 +11,7 @@ from theano.misc.ordered_set import OrderedSet ...@@ -11,6 +11,7 @@ from theano.misc.ordered_set import OrderedSet
from fg import InconsistencyError from fg import InconsistencyError
class ProtocolError(Exception): class ProtocolError(Exception):
"""Raised when FunctionGraph calls DestroyHandler callbacks in """Raised when FunctionGraph calls DestroyHandler callbacks in
an invalid way, for example, pruning or changing a node that has an invalid way, for example, pruning or changing a node that has
...@@ -18,6 +19,7 @@ class ProtocolError(Exception): ...@@ -18,6 +19,7 @@ class ProtocolError(Exception):
""" """
pass pass
def _contains_cycle(fgraph, orderings): def _contains_cycle(fgraph, orderings):
""" """
...@@ -44,7 +46,6 @@ def _contains_cycle(fgraph, orderings): ...@@ -44,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
inputs = fgraph.inputs inputs = fgraph.inputs
outputs = fgraph.outputs outputs = fgraph.outputs
# this is hard-coded reimplementation of functions from graph.py # this is hard-coded reimplementation of functions from graph.py
# reason: go faster, prepare for port to C. # reason: go faster, prepare for port to C.
# specifically, it could be replaced with a wrapper # specifically, it could be replaced with a wrapper
...@@ -55,7 +56,6 @@ def _contains_cycle(fgraph, orderings): ...@@ -55,7 +56,6 @@ def _contains_cycle(fgraph, orderings):
# this is performance-critical code. it is the largest single-function # this is performance-critical code. it is the largest single-function
# bottleneck when compiling large graphs. # bottleneck when compiling large graphs.
assert isinstance(outputs, (tuple, list, deque)) assert isinstance(outputs, (tuple, list, deque))
# TODO: For more speed - use a defaultdict for the orderings # TODO: For more speed - use a defaultdict for the orderings
...@@ -111,7 +111,7 @@ def _contains_cycle(fgraph, orderings): ...@@ -111,7 +111,7 @@ def _contains_cycle(fgraph, orderings):
# this is faster than calling get_parents # this is faster than calling get_parents
owner = var.owner owner = var.owner
if owner: if owner:
parents = [ owner ] parents = [owner]
else: else:
parents = [] parents = []
...@@ -172,16 +172,16 @@ def _contains_cycle(fgraph, orderings): ...@@ -172,16 +172,16 @@ def _contains_cycle(fgraph, orderings):
# and increment the visited node count without double-counting # and increment the visited node count without double-counting
node = visitable.popleft() node = visitable.popleft()
visited += 1 visited += 1
for client in node_to_children.get(node,[]): for client in node_to_children.get(node, []):
parent_counts[client] -= 1 parent_counts[client] -= 1
# If all of a node's parents have been visited, # If all of a node's parents have been visited,
# it may now be visited too # it may now be visited too
if not parent_counts[client]: if not parent_counts[client]:
visitable.append(client) visitable.append(client)
return visited != len(parent_counts) return visited != len(parent_counts)
def getroot(r, view_i): def getroot(r, view_i):
""" """
TODO: what is view_i ? based on add_impact's docstring, IG is guessing TODO: what is view_i ? based on add_impact's docstring, IG is guessing
...@@ -197,6 +197,7 @@ def getroot(r, view_i): ...@@ -197,6 +197,7 @@ def getroot(r, view_i):
except KeyError: except KeyError:
return r return r
def add_impact(r, view_o, impact): def add_impact(r, view_o, impact):
""" """
In opposition to getroot, which finds the variable that is viewed *by* r, this function In opposition to getroot, which finds the variable that is viewed *by* r, this function
...@@ -211,15 +212,17 @@ def add_impact(r, view_o, impact): ...@@ -211,15 +212,17 @@ def add_impact(r, view_o, impact):
IG thinks so, based on reading the code. It looks like get_impact IG thinks so, based on reading the code. It looks like get_impact
does what this docstring said this function does. does what this docstring said this function does.
""" """
for v in view_o.get(r,[]): for v in view_o.get(r, []):
impact.add(v) impact.add(v)
add_impact(v, view_o, impact) add_impact(v, view_o, impact)
def get_impact(root, view_o): def get_impact(root, view_o):
impact = OrderedSet() impact = OrderedSet()
add_impact(root, view_o, impact) add_impact(root, view_o, impact)
return impact return impact
def fast_inplace_check(inputs): def fast_inplace_check(inputs):
""" Return the variables in inputs that are posible candidate for as inputs of inplace operation """ Return the variables in inputs that are posible candidate for as inputs of inplace operation
...@@ -227,12 +230,14 @@ def fast_inplace_check(inputs): ...@@ -227,12 +230,14 @@ def fast_inplace_check(inputs):
:param inputs: inputs Variable that you want to use as inplace destination :param inputs: inputs Variable that you want to use as inplace destination
""" """
fgraph = inputs[0].fgraph fgraph = inputs[0].fgraph
protected_inputs = [f.protected for f in fgraph._features if isinstance(f,theano.compile.function_module.Supervisor)] Supervisor = theano.compile.function_module.Supervisor
protected_inputs = sum(protected_inputs,[])#flatten the list protected_inputs = [f.protected for f in fgraph._features
if isinstance(f, Supervisor)]
protected_inputs = sum(protected_inputs, []) # flatten the list
protected_inputs.extend(fgraph.outputs) protected_inputs.extend(fgraph.outputs)
inputs = [i for i in inputs if inputs = [i for i in inputs if
not isinstance(i,graph.Constant) not isinstance(i, graph.Constant)
and not fgraph.destroyers(i) and not fgraph.destroyers(i)
and i not in protected_inputs] and i not in protected_inputs]
return inputs return inputs
...@@ -294,14 +299,18 @@ if 0: ...@@ -294,14 +299,18 @@ if 0:
if self.fgraph is fgraph: if self.fgraph is fgraph:
already_there = True already_there = True
if self.fgraph not in [None, fgraph]: if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)") raise Exception("A DestroyHandler instance can only serve"
" one FunctionGraph. (Matthew 6:24)")
for attr in ('destroyers', 'destroy_handler'): for attr in ('destroyers', 'destroy_handler'):
if hasattr(fgraph, attr): if hasattr(fgraph, attr):
already_there = True already_there = True
if already_there: if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment # FunctionGraph.attach_feature catches AlreadyThere
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.") # and cancels the attachment
raise toolbox.AlreadyThere(
"DestroyHandler feature is already present or in"
" conflict with another plugin.")
####### end of checking ############ ####### end of checking ############
...@@ -342,7 +351,7 @@ if 0: ...@@ -342,7 +351,7 @@ if 0:
def _build_droot_impact(self): def _build_droot_impact(self):
droot = {} # destroyed view + nonview variables -> foundation droot = {} # destroyed view + nonview variables -> foundation
impact = {} # destroyed nonview variable -> it + all views of it impact = {} # destroyed nonview variable -> it + all views of it
root_destroyer = {} # root -> destroyer apply root_destroyer = {} # root -> destroyer apply
for app in self.destroyers: for app in self.destroyers:
for output_idx, input_idx_list in app.op.destroy_map.items(): for output_idx, input_idx_list in app.op.destroy_map.items():
...@@ -352,7 +361,8 @@ if 0: ...@@ -352,7 +361,8 @@ if 0:
input = app.inputs[input_idx] input = app.inputs[input_idx]
input_root = getroot(input, self.view_i) input_root = getroot(input, self.view_i)
if input_root in droot: if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root) raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root droot[input_root] = input_root
root_destroyer[input_root] = app root_destroyer[input_root] = app
#input_impact = set([input_root]) #input_impact = set([input_root])
...@@ -394,7 +404,9 @@ if 0: ...@@ -394,7 +404,9 @@ if 0:
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op)) raise NotImplementedError(
'destroying this output invalidates multiple inputs',
(app. op))
o = app.outputs[o_idx] o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]] i = app.inputs[i_idx_list[0]]
self.view_i[o] = i self.view_i[o] = i
...@@ -402,7 +414,7 @@ if 0: ...@@ -402,7 +414,7 @@ if 0:
# update self.clients # update self.clients
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
self.clients.setdefault(input, {}).setdefault(app,0) self.clients.setdefault(input, {}).setdefault(app, 0)
self.clients[input][app] += 1 self.clients[input][app] += 1
for i, output in enumerate(app.outputs): for i, output in enumerate(app.outputs):
...@@ -460,7 +472,8 @@ if 0: ...@@ -460,7 +472,8 @@ if 0:
self.clients[new_r][app] += 1 self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o #UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map',
{}).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs #destroying this output invalidates multiple inputs
raise NotImplementedError() raise NotImplementedError()
...@@ -493,7 +506,8 @@ if 0: ...@@ -493,7 +506,8 @@ if 0:
ords = self.orderings(fgraph) ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords): if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles") raise InconsistencyError(
"Dependency graph contains cycles")
else: else:
#James's Conjecture: #James's Conjecture:
#If there are no destructive ops, then there can be no cycles. #If there are no destructive ops, then there can be no cycles.
...@@ -518,13 +532,14 @@ if 0: ...@@ -518,13 +532,14 @@ if 0:
droot, impact, __ignore = self.refresh_droot_impact() droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants # check for destruction of constants
illegal_destroy = [r for r in droot if \ illegal_destroy = [r for r in droot if
getattr(r.tag,'indestructible', False) or \ getattr(r.tag,'indestructible', False) or
isinstance(r, graph.Constant)] isinstance(r, graph.Constant)]
if illegal_destroy: if illegal_destroy:
#print 'destroying illegally' #print 'destroying illegally'
raise InconsistencyError("Attempting to destroy indestructible variables: %s" % raise InconsistencyError(
illegal_destroy) "Attempting to destroy indestructible variables: %s" %
illegal_destroy)
# add destroyed variable clients as computational dependencies # add destroyed variable clients as computational dependencies
for app in self.destroyers: for app in self.destroyers:
...@@ -569,15 +584,20 @@ if 0: ...@@ -569,15 +584,20 @@ if 0:
#CHECK FOR INPUT ALIASING #CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import # OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', []) tolerate_same = getattr(app.op,
'destroyhandler_tolerate_same',
[])
assert isinstance(tolerate_same, list) assert isinstance(tolerate_same, list)
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same tolerated = OrderedSet(idx1 for idx0, idx1 in
if idx0 == destroyed_idx) tolerate_same
if idx0 == destroyed_idx)
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', []) tolerate_aliased = getattr(
app.op, 'destroyhandler_tolerate_aliased', [])
assert isinstance(tolerate_aliased, list) assert isinstance(tolerate_aliased, list)
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased ignored = OrderedSet(idx1 for idx0, idx1
if idx0 == destroyed_idx) in tolerate_aliased
if idx0 == destroyed_idx)
#print 'tolerated', tolerated #print 'tolerated', tolerated
#print 'ignored', ignored #print 'ignored', ignored
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
...@@ -600,6 +620,7 @@ if 0: ...@@ -600,6 +620,7 @@ if 0:
return rval return rval
class DestroyHandler(toolbox.Bookkeeper): class DestroyHandler(toolbox.Bookkeeper):
""" """
The DestroyHandler class detects when a graph is impossible to evaluate The DestroyHandler class detects when a graph is impossible to evaluate
...@@ -642,7 +663,6 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -642,7 +663,6 @@ class DestroyHandler(toolbox.Bookkeeper):
<unknown> <unknown>
""" """
def __init__(self, do_imports_on_attach=True): def __init__(self, do_imports_on_attach=True):
self.fgraph = None self.fgraph = None
self.do_imports_on_attach = do_imports_on_attach self.do_imports_on_attach = do_imports_on_attach
...@@ -686,14 +706,18 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -686,14 +706,18 @@ class DestroyHandler(toolbox.Bookkeeper):
if self.fgraph is fgraph: if self.fgraph is fgraph:
already_there = True already_there = True
if self.fgraph is not None: if self.fgraph is not None:
raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)") raise Exception(
"A DestroyHandler instance can only serve one"
" FunctionGraph. (Matthew 6:24)")
for attr in ('destroyers', 'destroy_handler'): for attr in ('destroyers', 'destroy_handler'):
if hasattr(fgraph, attr): if hasattr(fgraph, attr):
already_there = True already_there = True
if already_there: if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.") raise toolbox.AlreadyThere(
"DestroyHandler feature is already present"
" or in conflict with another plugin.")
####### Annotate the FunctionGraph ############ ####### Annotate the FunctionGraph ############
...@@ -738,7 +762,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -738,7 +762,8 @@ class DestroyHandler(toolbox.Bookkeeper):
input = app.inputs[input_idx] input = app.inputs[input_idx]
input_root = getroot(input, self.view_i) input_root = getroot(input, self.view_i)
if input_root in droot: if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root) raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root droot[input_root] = input_root
root_destroyer[input_root] = app root_destroyer[input_root] = app
input_impact = get_impact(input_root, self.view_o) input_impact = get_impact(input_root, self.view_o)
...@@ -768,7 +793,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -768,7 +793,8 @@ class DestroyHandler(toolbox.Bookkeeper):
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
if app in self.debug_all_apps: raise ProtocolError("double import") if app in self.debug_all_apps:
raise ProtocolError("double import")
self.debug_all_apps.add(app) self.debug_all_apps.add(app)
#print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) #print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
...@@ -777,9 +803,12 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -777,9 +803,12 @@ class DestroyHandler(toolbox.Bookkeeper):
self.destroyers.add(app) self.destroyers.add(app)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op)) raise NotImplementedError(
'destroying this output invalidates multiple inputs',
(app. op))
o = app.outputs[o_idx] o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]] i = app.inputs[i_idx_list[0]]
self.view_i[o] = i self.view_i[o] = i
...@@ -787,7 +816,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -787,7 +816,7 @@ class DestroyHandler(toolbox.Bookkeeper):
# update self.clients # update self.clients
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
self.clients.setdefault(input, OrderedDict()).setdefault(app,0) self.clients.setdefault(input, OrderedDict()).setdefault(app, 0)
self.clients[input][app] += 1 self.clients[input][app] += 1
for i, output in enumerate(app.outputs): for i, output in enumerate(app.outputs):
...@@ -797,7 +826,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -797,7 +826,8 @@ class DestroyHandler(toolbox.Bookkeeper):
def on_prune(self, fgraph, app, reason): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
if app not in self.debug_all_apps: raise ProtocolError("prune without import") if app not in self.debug_all_apps:
raise ProtocolError("prune without import")
self.debug_all_apps.remove(app) self.debug_all_apps.remove(app)
#UPDATE self.clients #UPDATE self.clients
...@@ -812,7 +842,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -812,7 +842,8 @@ class DestroyHandler(toolbox.Bookkeeper):
# deleted on_detach(). # deleted on_detach().
#UPDATE self.view_i, self.view_o #UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs #destroying this output invalidates multiple inputs
raise NotImplementedError() raise NotImplementedError()
...@@ -834,7 +865,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -834,7 +865,8 @@ class DestroyHandler(toolbox.Bookkeeper):
# considered 'outputs' of the graph. # considered 'outputs' of the graph.
pass pass
else: else:
if app not in self.debug_all_apps: raise ProtocolError("change without import") if app not in self.debug_all_apps:
raise ProtocolError("change without import")
#UPDATE self.clients #UPDATE self.clients
self.clients[old_r][app] -= 1 self.clients[old_r][app] -= 1
...@@ -845,7 +877,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -845,7 +877,8 @@ class DestroyHandler(toolbox.Bookkeeper):
self.clients[new_r][app] += 1 self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o #UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs #destroying this output invalidates multiple inputs
raise NotImplementedError() raise NotImplementedError()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论