提交 cbf36515 authored 作者: Frederic's avatar Frederic

pep8

上级 0d3dffac
......@@ -11,6 +11,7 @@ from theano.misc.ordered_set import OrderedSet
from fg import InconsistencyError
class ProtocolError(Exception):
"""Raised when FunctionGraph calls DestroyHandler callbacks in
an invalid way, for example, pruning or changing a node that has
......@@ -18,6 +19,7 @@ class ProtocolError(Exception):
"""
pass
def _contains_cycle(fgraph, orderings):
"""
......@@ -44,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
inputs = fgraph.inputs
outputs = fgraph.outputs
# this is hard-coded reimplementation of functions from graph.py
# reason: go faster, prepare for port to C.
# specifically, it could be replaced with a wrapper
......@@ -55,7 +56,6 @@ def _contains_cycle(fgraph, orderings):
# this is performance-critical code. it is the largest single-function
# bottleneck when compiling large graphs.
assert isinstance(outputs, (tuple, list, deque))
# TODO: For more speed - use a defaultdict for the orderings
......@@ -111,7 +111,7 @@ def _contains_cycle(fgraph, orderings):
# this is faster than calling get_parents
owner = var.owner
if owner:
parents = [ owner ]
parents = [owner]
else:
parents = []
......@@ -172,16 +172,16 @@ def _contains_cycle(fgraph, orderings):
# and increment the visited node count without double-counting
node = visitable.popleft()
visited += 1
for client in node_to_children.get(node,[]):
for client in node_to_children.get(node, []):
parent_counts[client] -= 1
# If all of a node's parents have been visited,
# it may now be visited too
if not parent_counts[client]:
visitable.append(client)
return visited != len(parent_counts)
def getroot(r, view_i):
"""
TODO: what is view_i ? based on add_impact's docstring, IG is guessing
......@@ -197,6 +197,7 @@ def getroot(r, view_i):
except KeyError:
return r
def add_impact(r, view_o, impact):
"""
In opposition to getroot, which finds the variable that is viewed *by* r, this function
......@@ -211,15 +212,17 @@ def add_impact(r, view_o, impact):
IG thinks so, based on reading the code. It looks like get_impact
does what this docstring said this function does.
"""
for v in view_o.get(r,[]):
for v in view_o.get(r, []):
impact.add(v)
add_impact(v, view_o, impact)
def get_impact(root, view_o):
impact = OrderedSet()
add_impact(root, view_o, impact)
return impact
def fast_inplace_check(inputs):
""" Return the variables in inputs that are posible candidate for as inputs of inplace operation
......@@ -227,12 +230,14 @@ def fast_inplace_check(inputs):
:param inputs: inputs Variable that you want to use as inplace destination
"""
fgraph = inputs[0].fgraph
protected_inputs = [f.protected for f in fgraph._features if isinstance(f,theano.compile.function_module.Supervisor)]
protected_inputs = sum(protected_inputs,[])#flatten the list
Supervisor = theano.compile.function_module.Supervisor
protected_inputs = [f.protected for f in fgraph._features
if isinstance(f, Supervisor)]
protected_inputs = sum(protected_inputs, []) # flatten the list
protected_inputs.extend(fgraph.outputs)
inputs = [i for i in inputs if
not isinstance(i,graph.Constant)
not isinstance(i, graph.Constant)
and not fgraph.destroyers(i)
and i not in protected_inputs]
return inputs
......@@ -294,14 +299,18 @@ if 0:
if self.fgraph is fgraph:
already_there = True
if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)")
raise Exception("A DestroyHandler instance can only serve"
" one FunctionGraph. (Matthew 6:24)")
for attr in ('destroyers', 'destroy_handler'):
if hasattr(fgraph, attr):
already_there = True
if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
# FunctionGraph.attach_feature catches AlreadyThere
# and cancels the attachment
raise toolbox.AlreadyThere(
"DestroyHandler feature is already present or in"
" conflict with another plugin.")
####### end of checking ############
......@@ -342,7 +351,7 @@ if 0:
def _build_droot_impact(self):
droot = {} # destroyed view + nonview variables -> foundation
impact = {} # destroyed nonview variable -> it + all views of it
root_destroyer = {} # root -> destroyer apply
root_destroyer = {} # root -> destroyer apply
for app in self.destroyers:
for output_idx, input_idx_list in app.op.destroy_map.items():
......@@ -352,7 +361,8 @@ if 0:
input = app.inputs[input_idx]
input_root = getroot(input, self.view_i)
if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root)
raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
#input_impact = set([input_root])
......@@ -394,7 +404,9 @@ if 0:
# add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op))
raise NotImplementedError(
'destroying this output invalidates multiple inputs',
(app. op))
o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]]
self.view_i[o] = i
......@@ -402,7 +414,7 @@ if 0:
# update self.clients
for i, input in enumerate(app.inputs):
self.clients.setdefault(input, {}).setdefault(app,0)
self.clients.setdefault(input, {}).setdefault(app, 0)
self.clients[input][app] += 1
for i, output in enumerate(app.outputs):
......@@ -460,7 +472,8 @@ if 0:
self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map',
{}).items():
if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs
raise NotImplementedError()
......@@ -493,7 +506,8 @@ if 0:
ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles")
raise InconsistencyError(
"Dependency graph contains cycles")
else:
#James's Conjecture:
#If there are no destructive ops, then there can be no cycles.
......@@ -518,13 +532,14 @@ if 0:
droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants
illegal_destroy = [r for r in droot if \
getattr(r.tag,'indestructible', False) or \
illegal_destroy = [r for r in droot if
getattr(r.tag,'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy:
#print 'destroying illegally'
raise InconsistencyError("Attempting to destroy indestructible variables: %s" %
illegal_destroy)
raise InconsistencyError(
"Attempting to destroy indestructible variables: %s" %
illegal_destroy)
# add destroyed variable clients as computational dependencies
for app in self.destroyers:
......@@ -569,15 +584,20 @@ if 0:
#CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerate_same = getattr(app.op,
'destroyhandler_tolerate_same',
[])
assert isinstance(tolerate_same, list)
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx)
tolerated = OrderedSet(idx1 for idx0, idx1 in
tolerate_same
if idx0 == destroyed_idx)
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
tolerate_aliased = getattr(
app.op, 'destroyhandler_tolerate_aliased', [])
assert isinstance(tolerate_aliased, list)
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx)
ignored = OrderedSet(idx1 for idx0, idx1
in tolerate_aliased
if idx0 == destroyed_idx)
#print 'tolerated', tolerated
#print 'ignored', ignored
for i, input in enumerate(app.inputs):
......@@ -600,6 +620,7 @@ if 0:
return rval
class DestroyHandler(toolbox.Bookkeeper):
"""
The DestroyHandler class detects when a graph is impossible to evaluate
......@@ -642,7 +663,6 @@ class DestroyHandler(toolbox.Bookkeeper):
<unknown>
"""
def __init__(self, do_imports_on_attach=True):
self.fgraph = None
self.do_imports_on_attach = do_imports_on_attach
......@@ -686,14 +706,18 @@ class DestroyHandler(toolbox.Bookkeeper):
if self.fgraph is fgraph:
already_there = True
if self.fgraph is not None:
raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)")
raise Exception(
"A DestroyHandler instance can only serve one"
" FunctionGraph. (Matthew 6:24)")
for attr in ('destroyers', 'destroy_handler'):
if hasattr(fgraph, attr):
already_there = True
if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
raise toolbox.AlreadyThere(
"DestroyHandler feature is already present"
" or in conflict with another plugin.")
####### Annotate the FunctionGraph ############
......@@ -738,7 +762,8 @@ class DestroyHandler(toolbox.Bookkeeper):
input = app.inputs[input_idx]
input_root = getroot(input, self.view_i)
if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root)
raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
input_impact = get_impact(input_root, self.view_o)
......@@ -768,7 +793,8 @@ class DestroyHandler(toolbox.Bookkeeper):
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
if app in self.debug_all_apps: raise ProtocolError("double import")
if app in self.debug_all_apps:
raise ProtocolError("double import")
self.debug_all_apps.add(app)
#print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
......@@ -777,9 +803,12 @@ class DestroyHandler(toolbox.Bookkeeper):
self.destroyers.add(app)
# add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op))
raise NotImplementedError(
'destroying this output invalidates multiple inputs',
(app. op))
o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]]
self.view_i[o] = i
......@@ -787,7 +816,7 @@ class DestroyHandler(toolbox.Bookkeeper):
# update self.clients
for i, input in enumerate(app.inputs):
self.clients.setdefault(input, OrderedDict()).setdefault(app,0)
self.clients.setdefault(input, OrderedDict()).setdefault(app, 0)
self.clients[input][app] += 1
for i, output in enumerate(app.outputs):
......@@ -797,7 +826,8 @@ class DestroyHandler(toolbox.Bookkeeper):
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
if app not in self.debug_all_apps: raise ProtocolError("prune without import")
if app not in self.debug_all_apps:
raise ProtocolError("prune without import")
self.debug_all_apps.remove(app)
#UPDATE self.clients
......@@ -812,7 +842,8 @@ class DestroyHandler(toolbox.Bookkeeper):
# deleted on_detach().
#UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs
raise NotImplementedError()
......@@ -834,7 +865,8 @@ class DestroyHandler(toolbox.Bookkeeper):
# considered 'outputs' of the graph.
pass
else:
if app not in self.debug_all_apps: raise ProtocolError("change without import")
if app not in self.debug_all_apps:
raise ProtocolError("change without import")
#UPDATE self.clients
self.clients[old_r][app] -= 1
......@@ -845,7 +877,8 @@ class DestroyHandler(toolbox.Bookkeeper):
self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs
raise NotImplementedError()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论