Commit 213db8fd authored by James Bergstra

rewrote DestroyHandler and moved it to destroyhandler.py, moved view_roots to graph.py

Parent: 99a92131
from collections import defaultdict
import toolbox
import graph
from env import InconsistencyError
class ProtocolError(Exception):
    """Raised when Bookkeeper callbacks arrive in an impossible order
    (e.g. an Apply imported twice, or pruned without being imported)."""
    pass
class DestroyHandler(toolbox.Bookkeeper):
    """Multiplexing front-end: forwards every Bookkeeper event to a
    per-Env ``DestroyHandlerHelper2`` instance, creating one lazily on
    first attach."""

    def __init__(self):
        # env -> DestroyHandlerHelper2 serving that env
        self.map = {}

    def on_attach(self, env):
        helper = self.map.setdefault(env, DestroyHandlerHelper2())
        helper.on_attach(env)

    def on_detach(self, env):
        self.map[env].on_detach(env)

    def on_import(self, env, op):
        self.map[env].on_import(env, op)

    def on_prune(self, env, op):
        self.map[env].on_prune(env, op)

    def on_change_input(self, env, node, i, r, new_r):
        self.map[env].on_change_input(env, node, i, r, new_r)

    def validate(self, env):
        self.map[env].validate(env)

    def orderings(self, env):
        return self.map[env].orderings(env)
class DestroyHandlerHelper2(toolbox.Bookkeeper):
    """Per-Env bookkeeping of view/destroy relationships.

    Tracks which Apply nodes destroy (overwrite) an input and which
    outputs are views of which inputs, so that ``orderings`` can impose
    the scheduling constraints that make in-place operations safe and
    ``validate`` can reject graphs with no safe evaluation order.
    """

    def __init__(self):
        # The Env this helper serves; one Env per helper instance.
        self.env = None

    def on_attach(self, env):
        """Attach to `env`, publishing ``env.destroyers`` and
        ``env.destroy_handler`` and initializing internal maps.

        Raises toolbox.AlreadyThere if another destroy handler (or a
        conflicting plugin) is already installed on `env`.
        """
        # boilerplate from old implementation
        if self.env is not None:
            raise Exception("A DestroyHandler instance can only serve one Env.")
        for attr in ('destroyers', 'destroy_handler'):
            if hasattr(env, attr):
                raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")

        def get_destroyers(r):
            # Return the (at most one) Apply destroying result r, as a list.
            d_of = self.get_destroyer_of(r)
            if d_of:
                return [d_of]
            else:
                return []

        env.destroyers = get_destroyers
        env.destroy_handler = self
        self.env = env
        self.destroyers = set()  # set of Apply instances with non-null destroy_map
        self.view_i = {}  # result -> result it is a view of
        self.view_o = {}  # result -> set of results that are views of it
        # clients: how many times does an apply use a given result
        self.clients = {}  # result -> apply -> ninputs
        self.debug_all_apps = set()  # every imported Apply (protocol checking)
        toolbox.Bookkeeper.on_attach(self, env)

    def build_droot_impact(self):
        """Return ``(droot, impact, root_destroyer)``.

        droot: destroyed view + nonview results -> foundation
        impact: destroyed nonview result -> it + all views of it
        root_destroyer: root -> destroyer apply

        Raises InconsistencyError if some root is destroyed twice, and
        NotImplementedError for multi-input destroy_map entries.
        """
        droot = {}
        impact = {}
        root_destroyer = {}

        def getroot(r):
            # Follow view_i links back to the non-view foundation of r.
            try:
                return getroot(self.view_i[r])
            except KeyError:
                return r

        for app in self.destroyers:
            for output_idx, input_idx_list in app.op.destroy_map.items():
                if len(input_idx_list) != 1:
                    raise NotImplementedError()
                input_idx = input_idx_list[0]
                input = app.inputs[input_idx]
                input_root = getroot(input)
                if input_root in droot:
                    raise InconsistencyError("Multiple destroyers of %s" % input_root)
                droot[input_root] = input_root
                root_destroyer[input_root] = app
                impact[input_root] = set([input_root])

                def build_stuff(r):
                    # Recursively mark every transitive view of r as part
                    # of input_root's impact set.
                    for v in self.view_o.get(r, []):
                        assert v not in droot
                        droot[v] = input_root
                        impact[input_root].add(v)
                        build_stuff(v)

                build_stuff(input_root)
        return droot, impact, root_destroyer

    def get_destroyer_of(self, r):
        """Return the Apply that destroys `r` (directly or through a view),
        or None if nothing destroys it."""
        droot, impact, root_destroyer = self.build_droot_impact()
        for root in impact:
            if r in impact[root]:
                return root_destroyer[root]
        return None

    def on_detach(self, env):
        """Detach from `env`, removing the attributes installed by on_attach."""
        if env is not self.env:
            raise Exception("detaching wrong env", env)
        del self.destroyers
        del self.view_i
        del self.view_o
        del self.clients
        # BUG FIX: on_attach installs 'destroy_handler'; the original
        # asserted on the nonexistent attribute 'destroyer_handler',
        # which made on_detach raise AttributeError unconditionally.
        assert self.env.destroy_handler is self
        delattr(self.env, 'destroyers')
        delattr(self.env, 'destroy_handler')
        self.env = None

    def on_import(self, env, app):
        """Add Apply instance to set which must be computed."""
        if app in self.debug_all_apps:
            raise ProtocolError("double import")
        self.debug_all_apps.add(app)
        # If it's a destructive op, add it to our watch list
        if getattr(app.op, 'destroy_map', {}):
            self.destroyers.add(app)
        # add this symbol to the forward and backward view maps
        for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
            if len(i_idx_list) > 1:
                # an output viewing multiple inputs is not supported
                raise NotImplementedError()
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]
            self.view_i[o] = i
            self.view_o.setdefault(i, set()).add(o)
        # update self.clients
        for i, input in enumerate(app.inputs):
            self.clients.setdefault(input, {}).setdefault(app, 0)
            self.clients[input][app] += 1
        for i, output in enumerate(app.outputs):
            self.clients.setdefault(output, {})

    def on_prune(self, env, app):
        """Remove Apply instance from set which must be computed."""
        if app not in self.debug_all_apps:
            raise ProtocolError("prune without import")
        self.debug_all_apps.remove(app)
        # UPDATE self.clients
        for i, input in enumerate(app.inputs):
            del self.clients[input][app]
        if getattr(app.op, 'destroy_map', {}):
            self.destroyers.remove(app)
        # Note: leaving empty client dictionaries in the struct.
        # Why? It's a pain to remove them. I think they aren't doing any harm,
        # they will be deleted on_detach().
        # UPDATE self.view_i, self.view_o
        for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
            if len(i_idx_list) > 1:
                # an output viewing multiple inputs is not supported
                raise NotImplementedError()
            o = app.outputs[o_idx]
            i = app.inputs[i_idx_list[0]]
            del self.view_i[o]
            self.view_o[i].remove(o)
            if not self.view_o[i]:
                del self.view_o[i]

    def on_change_input(self, env, app, i, old_r, new_r):
        """app.inputs[i] changed from old_r to new_r."""
        if app == 'output':
            # app == 'output' is a special key that means Env is redefining
            # which nodes are being considered 'outputs' of the graph.
            pass
        else:
            if app not in self.debug_all_apps:
                raise ProtocolError("change without import")
            # UPDATE self.clients
            self.clients[old_r][app] -= 1
            if self.clients[old_r][app] == 0:
                del self.clients[old_r][app]
            self.clients.setdefault(new_r, {}).setdefault(app, 0)
            self.clients[new_r][app] += 1
            # UPDATE self.view_i, self.view_o
            for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
                if len(i_idx_list) > 1:
                    # an output viewing multiple inputs is not supported
                    raise NotImplementedError()
                i_idx = i_idx_list[0]
                output = app.outputs[o_idx]
                if i_idx == i:
                    if app.inputs[i_idx] is not new_r:
                        raise ProtocolError("wrong new_r on change")
                    self.view_i[output] = new_r
                    self.view_o[old_r].remove(output)
                    if not self.view_o[old_r]:
                        del self.view_o[old_r]
                    self.view_o.setdefault(new_r, set()).add(output)

    def validate(self, env):
        """Return True when a valid evaluation order exists.

        Raise InconsistencyError when
        a) orderings() raises an error, or
        b) orderings cannot be topologically sorted.
        """
        if self.destroyers:
            ords = self.orderings(env)
            try:
                graph.io_toposort(env.inputs, env.outputs, ords)
            except ValueError as e:  # BUG FIX: was Python-2-only 'except ValueError, e'
                if 'cycles' in str(e):
                    raise InconsistencyError("Dependency graph contains cycles")
                else:
                    raise
        else:
            # James's Conjecture:
            # If there are no destructive ops, then there can be no cycles.
            pass
        return True

    def orderings(self, env):
        """Return orderings induced by destructive operations.

        Raise InconsistencyError when
        a) attempting to destroy an indestructible result, or
        b) attempting to destroy a value multiple times, or
        c) an Apply destroys (illegally) one of its own inputs by aliasing.
        """
        rval = {}
        if self.destroyers:
            # BUILD DATA STRUCTURES
            # CHECK for multiple destructions during construction of variables
            droot, impact, __ignore = self.build_droot_impact()
            # check for destruction of constants / indestructible results
            illegal_destroy = [r for r in droot if
                               getattr(r.tag, 'indestructible', False) or
                               isinstance(r, graph.Constant)]
            if illegal_destroy:
                raise InconsistencyError("Attempting to destroy indestructible results: %s" %
                                         illegal_destroy)
            # add destroyed result clients as computational dependencies
            for app in self.destroyers:
                # for each destroyed input...
                for output_idx, input_idx_list in app.op.destroy_map.items():
                    destroyed_idx = input_idx_list[0]
                    destroyed_result = app.inputs[destroyed_idx]
                    root = droot[destroyed_result]
                    root_impact = impact[root]
                    # We generally want all clients of anything depending on
                    # root as pre-requisites of app.  But app is itself one
                    # such client!  App is always a client of
                    # destroyed_result; the tricky case is when it is ALSO a
                    # client of another result viewing the same root.
                    # Generally that is illegal (e.g. add_inplace(x, x.T)),
                    # but some ops work correctly with duplicated inputs
                    # (e.g. add_inplace(x, x)) and declare so through the
                    # 'tolerate_same' attribute: a list of (idx0, idx1)
                    # pairs where idx0 indexes a destroyed input and idx1
                    # a different input position where the very same
                    # variable may legally appear.
                    # CHECK FOR INPUT ALIASING
                    # OPT: pre-compute this on import
                    tolerate_same = getattr(app.op, 'tolerate_same', [])
                    tolerated = set(idx1 for idx0, idx1 in tolerate_same
                                    if idx0 == destroyed_idx)
                    tolerated.add(destroyed_idx)
                    for i, input in enumerate(app.inputs):
                        if input in root_impact \
                                and (i not in tolerated or input is not destroyed_result):
                            raise InconsistencyError("Input aliasing: %s (%i, %i)"
                                                     % (app, destroyed_idx, i))
                    # add the rule: app must be preceded by all other Apply
                    # instances that depend on destroyed_input
                    root_clients = set()
                    for r in root_impact:
                        # zero-count client entries must not exist
                        assert not [a for a, c in self.clients[r].items() if not c]
                        root_clients.update([a for a, c in self.clients[r].items() if c])
                    root_clients.remove(app)
                    if root_clients:
                        rval[app] = root_clients
        return rval
from collections import defaultdict
import graph
import utils
import toolbox
from utils import AbstractFunctionError
from env import InconsistencyError
class DestroyHandler(toolbox.Bookkeeper):
    """Thin dispatcher: each Env gets its own ``DestroyHandlerHelper``,
    and every Bookkeeper callback is relayed to it."""

    def __init__(self):
        # env -> DestroyHandlerHelper serving that env
        self.map = {}

    def on_attach(self, env):
        if env not in self.map:
            self.map[env] = DestroyHandlerHelper()
        self.map[env].on_attach(env)

    def on_detach(self, env):
        self.map[env].on_detach(env)

    def on_import(self, env, op):
        self.map[env].on_import(env, op)

    def on_prune(self, env, op):
        self.map[env].on_prune(env, op)

    def on_change_input(self, env, node, i, r, new_r):
        self.map[env].on_change_input(env, node, i, r, new_r)

    def validate(self, env):
        self.map[env].validate(env)

    def orderings(self, env):
        return self.map[env].orderings(env)
class DestroyHandlerHelper(toolbox.Bookkeeper):
    """
    This feature ensures that an env represents a consistent data flow
    when some Ops overwrite their inputs and/or provide "views" over
    some of their inputs. It does so by tracking dependencies between
    data at different stages of the graph and ensuring that
    destructive operations are performed after the destroyed data and
    all of its views have been processed.
    Examples:
    - (x += 1) + (x += 1) -> fails because the first += makes the second
      invalid
    - (a += b) + (c += a) -> succeeds but we have to do c += a first
    - (a += b) + (b += c) + (c += a) -> fails because there's a cyclical
      dependency (no possible ordering)
    This feature allows some optimizations (eg sub += for +) to be applied
    safely.
    @todo
    - x += transpose_view(x) -> fails because the input that is destroyed
      depends on an input that shares the same data
    """

    def __init__(self):
        # The Env this helper serves; one Env per helper instance.
        self.env = None

    def on_attach(self, env):
        """Attach to `env`; install env.destroyers/env.destroy_handler and
        initialize all tracking structures."""
        if self.env is not None:
            raise Exception("A DestroyHandler instance can only serve one Env.")
        for attr in ('destroyers', 'destroy_handler'):
            if hasattr(env, attr):
                raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")

        def __destroyers(r):
            # Ops currently destroying r's entry, as a list of keys.
            ret = self.destroyers.get(r, {})
            # list(...) keeps the historical list return type under Python 3,
            # where dict.keys() is a view.
            return list(ret.keys())

        env.destroyers = __destroyers
        env.destroy_handler = self
        self.env = env
        # For an Op that has a view_map, {output : input it is a view of}
        self.parent = {}
        # Reverse mapping of parent: {input : outputs that are a view of it}
        self.children = defaultdict(set)
        # {foundation : {op that destroys it : path}}
        # where foundation is a result such that (not self.parent[result])
        # and path is a sequence of results such that:
        #   * path[0] == foundation
        #   * self.parent[path[i]] == path[i-1]
        #   * path[-1] == output of the Op that is the Destroyer
        self.destroyers = {}
        # Cache for the paths computed by __path__.
        self.paths = {}
        # If any of dups, cycles or illegal is non-empty, the env is inconsistent:
        # results destroyed more than once
        self.dups = set()
        # sequences of results forming a dependency cycle, e.g.
        # [a, ... b, ... c, ... a] for ((a += b) + (b += c) + (c += a))
        self.cycles = set()
        # results destroyed by some Op but marked indestructible by the user
        self.illegal = set()
        self.seen = set()
        toolbox.Bookkeeper.on_attach(self, env)

    def on_detach(self, env):
        """Drop all tracking state.  NOTE(review): unlike on_attach, this
        does not remove env.destroyers/env.destroy_handler — preserved from
        the original behavior."""
        del self.parent
        del self.children
        del self.destroyers
        del self.paths
        del self.dups
        del self.cycles
        del self.illegal
        del self.seen
        self.env = None

    def __path__(self, r):
        """
        Returns a path from the result r is ultimately a view of, down
        to r, i.e. a path such that:
        - path[-1] == r
        - path[i] == parent[path[i+1]]
        - parent[path[0]] is unset
        Paths are memoized in self.paths for every prefix computed.
        """
        path = self.paths.get(r, None)
        if path:
            return path
        rval = [r]
        r = self.parent.get(r, None)
        while r:
            rval.append(r)
            r = self.parent.get(r, None)
        rval.reverse()
        for i, x in enumerate(rval):
            self.paths[x] = rval[0:i + 1]
        return rval

    def __views__(self, r):
        """
        Returns the set of results (inclusive of r) such that all the
        results in the set are views of r, directly or indirectly.
        """
        children = self.children[r]
        if not children:
            return [r]
        else:
            rval = [r]
            for child in children:
                rval += self.__views__(child)
            return utils.uniq(rval)

    def __users__(self, r):
        """
        Returns the outputs of all the ops that use r or a view
        of r. In other words, for all ops that have an input that
        is r or a view of r, adds their outputs to the list that
        is returned.
        """
        views = self.__views__(r)
        rval = list(r.owner.outputs) if r.owner else []
        for view in views:
            for node, i in view.clients:
                if node != 'output':
                    rval += node.outputs
        return utils.uniq(rval)

    def __pre__(self, op):
        """
        Returns all results that must be computed prior to computing
        this node.
        """
        rval = set()
        if op is None:
            return rval
        dmap = getattr(op.op, 'destroy_map', {})
        # Flatten the destroyed input indices.  (Replaces the Python-2-only
        # builtin reduce(list.__add__, ...).)
        dinputs = [idx for idxs in dmap.values() for idx in idxs]
        d_found = {}
        nd_found = {}
        for i, input in enumerate(op.inputs):
            # Get the basic result the input is a view of.
            path = self.__path__(input)
            foundation = path[0]
            destroyers = self.destroyers.get(foundation, set())
            # Is this op destroying the foundation? If yes,
            # all users of the foundation must be computed before
            # we overwrite its contents.
            if op in destroyers and i in dinputs:
                d_found[foundation] = i
                users = self.__users__(foundation)
                rval.update(users)
            else:
                nd_found[foundation] = i
        rval.update(op.inputs)  # obviously
        intersection = set(d_found.keys()).intersection(set(nd_found.keys()))
        if not intersection:
            rval.difference_update(op.outputs)  # this op's outputs will always be in the users
        else:
            # A foundation appears both destroyed and not destroyed: only
            # legal if the op explicitly tolerates the exact same variable
            # at both positions (see 'tolerate_same').
            allowed = getattr(op.op, 'tolerate_same', [])
            for item in intersection:
                i, j = d_found[item], nd_found[item]
                pair = i, j
                if not (op.inputs[i] is op.inputs[j]
                        and (pair in allowed or tuple(reversed(pair)) in allowed)):
                    break
            else:
                rval.difference_update(op.outputs)
        return rval

    def __detect_cycles_helper__(self, r, seq):
        """
        Does a depth-first search to find cycles in the graph of
        computation given a directed connection from an op to
        its __pre__ set.
        @type seq: sequence
        @param seq: nodes visited up to now
        @param r: current node
        If r is found in seq, we have a cycle and it is added to
        the set of cycles.
        """
        if r in seq:
            self.cycles.add(tuple(seq[seq.index(r):]))
            return
        pre = self.__pre__(r.owner)
        for r2 in pre:
            self.__detect_cycles_helper__(r2, seq + [r])

    def __detect_cycles__(self, start, just_remove=False):
        """
        Tries to find a cycle containing any of the users of
        start. Prior to doing so, we remove all existing cycles
        containing an user of start from the cycles set. If
        just_remove is True, we return immediately after removing the
        cycles.
        """
        users = set(self.__users__(start))
        users.add(start)
        for user in users:
            # iterate over a copy: we mutate self.cycles in the loop
            for cycle in set(self.cycles):
                if user in cycle:
                    self.cycles.remove(cycle)
        if just_remove:
            return
        for user in users:
            self.__detect_cycles_helper__(user, [])

    def get_maps(self, node):
        """
        @return: (vmap, dmap) where:
        - vmap -> {output : [inputs output is a view of]}
        - dmap -> {output : [inputs that are destroyed by the node
          (and presumably returned as that output)]}
        Raises ValueError if an index in either map is out of range.
        """
        # BUG FIX: 'except AttributeError, AbstractFunctionError:' caught
        # only AttributeError and rebound the exception to the name
        # AbstractFunctionError; a tuple catches both as intended.
        try:
            _vmap = node.op.view_map
        except (AttributeError, AbstractFunctionError):
            _vmap = {}
        try:
            _dmap = node.op.destroy_map
        except (AttributeError, AbstractFunctionError):
            _dmap = {}
        vmap = {}
        for oidx, iidxs in _vmap.items():
            if oidx < 0 or oidx >= node.nout:
                raise ValueError("In %s.view_map: output index out of range" % node.op, oidx, _vmap)
            if any(iidx < 0 or iidx >= node.nin for iidx in iidxs):
                raise ValueError("In %s.view_map: input index out of range" % node.op, iidxs, _vmap)
            vmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs]
        dmap = {}
        for oidx, iidxs in _dmap.items():
            if oidx < 0 or oidx >= node.nout:
                raise ValueError("In %s.destroy_map: output index out of range" % node.op, oidx, _dmap)
            if any(iidx < 0 or iidx >= node.nin for iidx in iidxs):
                raise ValueError("In %s.destroy_map: input index out of range" % node.op, iidxs, _dmap)
            dmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs]
        return vmap, dmap

    def on_import(self, env, op):
        """
        Recomputes the dependencies and searches for inconsistencies given
        that we just added a node to the env.
        """
        self.seen.add(op)
        view_map, destroy_map = self.get_maps(op)
        for input in op.inputs:
            self.children.setdefault(input, set())
        for i, output in enumerate(op.outputs):
            views = view_map.get(output, None)
            destroyed = destroy_map.get(output, None)
            if destroyed:
                for input in destroyed:
                    path = self.__path__(input)
                    self.__add_destroyer__(path + [output])
            elif views:
                if len(views) > 1:
                    # This is a limitation of DestroyHandler
                    # TODO: lift it (requires changes everywhere)
                    raise Exception("Output is a view of too many inputs.")
                self.parent[output] = views[0]
                for input in views:
                    self.children[input].add(output)
            self.children[output] = set()
        for output in op.outputs:
            # output has no users and is not in any cycle because it
            # is new. We must however check for cycles from the output
            # eg if we are importing F in F(a += b, a) we will obtain
            # the following cycle: [F.out, +=.out, F.out] because __pre__
            # of +=.out, since it is destructive, must contain all the
            # users of a including F.out. A cycle not involving F.out
            # cannot occur.
            self.__detect_cycles_helper__(output, [])

    def on_prune(self, env, op):
        """
        Recomputes the dependencies and searches for inconsistencies to remove
        given that we just removed a node from the env.
        """
        view_map, destroy_map = self.get_maps(op)
        if destroy_map:
            # Clean up self.destroyers considering that this op is gone.
            destroyers = []
            for i, input in enumerate(op.inputs):
                destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
            for destroyer in destroyers:
                path = destroyer.get(op, [])
                if path:
                    self.__remove_destroyer__(path)
        if view_map:
            # Clean the children of the inputs if this Op was a view of any of them.
            for i, input in enumerate(op.inputs):
                self.children[input].difference_update(op.outputs)
        for output in op.outputs:
            # was: try/del with a bare except — pop is equivalent and targeted
            self.paths.pop(output, None)
            # True means that we are just removing cycles pertaining to this
            # output, including cycles involving the users of the output
            # (there should be no more users after the op is pruned).
            # No new cycles can be added by removing a node.
            self.__detect_cycles__(output, True)
        # Clean up parents and children
        for i, output in enumerate(op.outputs):
            self.parent.pop(output, None)
            del self.children[output]
        self.seen.remove(op)

    def __add_destroyer__(self, path):
        """
        Processes the information that path[0] is destroyed by path[-1].owner.
        """
        foundation = path[0]
        target = path[-1]
        node = target.owner
        destroyers = self.destroyers.setdefault(foundation, {})
        path = destroyers.setdefault(node, path)
        if len(destroyers) > 1:
            self.dups.add(foundation)
        # results marked 'indestructible' (and constants) must not be destroyed
        if getattr(foundation.tag, 'indestructible', False) or isinstance(foundation, graph.Constant):
            self.illegal.add(foundation)

    def __remove_destroyer__(self, path):
        """
        Processes the information that path[0] is no longer destroyed by
        path[-1].owner.
        """
        foundation = path[0]
        target = path[-1]
        node = target.owner
        destroyers = self.destroyers[foundation]
        del destroyers[node]
        if not destroyers:
            if foundation in self.illegal:
                self.illegal.remove(foundation)
            del self.destroyers[foundation]
        elif len(destroyers) == 1 and foundation in self.dups:
            self.dups.remove(foundation)

    def on_change_input(self, env, node, i, r, new_r):
        """Relay a single input replacement to on_rewire, ignoring the
        'output' pseudo-node used for graph outputs."""
        if node != 'output':
            self.on_rewire(env, [(node, i)], r, new_r)

    def on_rewire(self, env, clients, r_1, r_2):
        """
        Recomputes the dependencies and searches for inconsistencies given
        that all the clients are moved from r_1 to r_2, clients being
        a list of (node, i) pairs such that node.inputs[i] used to be r_1
        and is now r_2.
        """
        path_1 = self.__path__(r_1)
        path_2 = self.__path__(r_2)
        # All the affected results one level below the replacement.
        prev = set()
        for op, i in clients:
            prev.update(op.outputs)
        # Here we look at what destroys r_1, directly or indirectly. Since we
        # replace r_1, we must adjust the destroyers. Each destroyer has a
        # path, as described in __path__ and __add_destroyer__. To adjust a
        # path containing r_1 at index idx and r_prev at index idx+1:
        #  * idx == len(path)-1: do nothing
        #  * r_prev not in prev: do nothing
        #  * else: concatenate path_2 to the part of the path after r_1.
        foundation = path_1[0]
        # list(...) so __remove_destroyer__ may mutate the dict while we
        # iterate (dict views under Python 3 would raise otherwise).
        destroyers = list(self.destroyers.get(foundation, {}).items())
        for op, path in destroyers:
            if r_1 in path:
                idx = path.index(r_1)
                if idx == len(path) - 1 or path[idx + 1] not in prev:
                    continue
                self.__remove_destroyer__(path)
                self.__add_destroyer__(path_2 + path[idx + 1:])
        # Clean up parents and children
        for op, i in clients:
            view_map, _ = self.get_maps(op)
            for output, inputs in view_map.items():
                if r_2 in inputs:
                    assert self.parent.get(output, None) == r_1
                    self.parent[output] = r_2
                    self.children[r_1].remove(output)
                    self.children[r_2].add(output)
        # Invalidate cached paths of everything viewing r_1 or r_2.
        for view in self.__views__(r_1):
            self.paths.pop(view, None)
        for view in self.__views__(r_2):
            self.paths.pop(view, None)
        # Recompute the cycles from both r_1 and r_2.
        # (could be narrower: only cycles with r_1 directly before a
        # result in prev)
        self.__detect_cycles__(r_1)
        self.children.setdefault(r_2, set())
        self.__detect_cycles__(r_2)

    def validate(self, env):
        """
        Raises an L{InconsistencyError} on any of the following conditions:
        - Some results are destroyed by more than one L{Op}
        - There is a cycle of preconditions
        - An L{Op} attempts to destroy an indestructible result.
        Returns True otherwise.
        """
        if self.dups:
            raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
        elif self.cycles:
            raise InconsistencyError("There are cycles: %s" % self.cycles)
        elif self.illegal:
            raise InconsistencyError("Attempting to destroy indestructible results: %s" % self.illegal)
        else:
            return True

    def orderings(self, env):
        """
        Returns a dict of {node : set(nodes that must be computed before it)}
        according to L{DestroyHandler}.
        In particular, all the users of a destroyed result have priority over
        the L{Op} that destroys the result.
        """
        self.validate(env)
        ords = {}
        for foundation, destroyers in self.destroyers.items():
            for op in destroyers.keys():
                ords.setdefault(op, set()).update(
                    [user.owner for user in self.__users__(foundation)
                     if user not in op.outputs])
        return ords
def view_roots(r):
    """
    Utility function that returns the leaves of a search through
    consecutive view_map()s.
    """
    node = r.owner
    if node is None:
        # r is a graph input: it is its own root.
        return [r]
    try:
        # Map each output result of the owner to the input indices it views.
        out_to_idxs = dict((node.outputs[o], idxs)
                           for o, idxs in node.op.view_map.items())
    except AttributeError:
        # Op has no view_map: r is not a view of anything.
        return [r]
    if r not in out_to_idxs:
        return [r]
    roots = []
    for idx in out_to_idxs[r]:
        roots += view_roots(node.inputs[idx])
    return roots
...@@ -497,7 +497,7 @@ def clone_with_equiv(i, o, d, missing_input_policy = 'fail', orphan_policy = 'co ...@@ -497,7 +497,7 @@ def clone_with_equiv(i, o, d, missing_input_policy = 'fail', orphan_policy = 'co
return [d[input] for input in i], [d[output] for output in o] return [d[input] for input in i], [d[output] for output in o]
def general_toposort(r_out, deps): def general_toposort(r_out, deps, debug_print = False):
""" """
@note: deps(i) should behave like a pure function (no funny business with @note: deps(i) should behave like a pure function (no funny business with
internal state) internal state)
...@@ -534,11 +534,11 @@ def general_toposort(r_out, deps): ...@@ -534,11 +534,11 @@ def general_toposort(r_out, deps):
sources.append(client) sources.append(client)
if len(rlist) != len(reachable): if len(rlist) != len(reachable):
print '' if debug_print:
print reachable print ''
print rlist print reachable
print rlist
raise 'failed to complete topological sort of given nodes' raise ValueError('graph contains cycles')
return rlist return rlist
...@@ -644,3 +644,26 @@ def as_string(i, o, ...@@ -644,3 +644,26 @@ def as_string(i, o,
return [describe(output) for output in o] return [describe(output) for output in o]
def view_roots(r):
    """
    Utility function that returns the leaves of a search through
    consecutive view_map()s.
    """
    node = r.owner
    if node is not None:
        try:
            # Re-key the owner's view_map by output result rather than index.
            mapping = {}
            for o_idx, in_idxs in node.op.view_map.items():
                mapping[node.outputs[o_idx]] = in_idxs
        except AttributeError:
            # No view_map on this Op: r is a root.
            return [r]
        if r in mapping:
            result = []
            for in_idx in mapping[r]:
                result += view_roots(node.inputs[in_idx])
            return result
        return [r]
    return [r]
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Sign in or register to post a comment