提交 66cb3bf6 authored 作者: Frederic's avatar Frederic

pep8

上级 23230db9
......@@ -4,7 +4,7 @@ from theano.compat import cmp
## {{{ http://code.activestate.com/recipes/578231/ (r1)
# Copyright (c) Oren Tirosh 2012
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
......@@ -22,6 +22,8 @@ from theano.compat import cmp
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
def memodict(f):
""" Memoization decorator for a function taking a single argument """
class memodict(defaultdict):
......@@ -37,10 +39,11 @@ def make_depends():
def depends((a, b)):
""" Returns True if a depends on b """
return (any(bout in a.inputs for bout in b.outputs)
or any(depends((ainp.owner, b)) for ainp in a.inputs
if ainp.owner))
or any(depends((ainp.owner, b)) for ainp in a.inputs
if ainp.owner))
return depends
def make_dependence_cmp():
""" Create a comparator to represent the dependence of nodes in a graph """
......@@ -53,12 +56,15 @@ def make_dependence_cmp():
Returns negative number if b depends on a
Returns 0 otherwise
"""
if depends((a, b)): return 1
if depends((b, a)): return -1
if depends((a, b)):
return 1
if depends((b, a)):
return -1
return 0
return dependence
def reverse_dict(d):
"""Reverses direction of dependence dict
......@@ -78,6 +84,7 @@ def reverse_dict(d):
result[val] = result.get(val, tuple()) + (key, )
return result
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
......@@ -112,6 +119,7 @@ def _toposort(edges):
raise ValueError("Input has cycles")
return L
def posort(l, *cmps):
""" Partially ordered sort with multiple comparators
......@@ -133,9 +141,9 @@ def posort(l, *cmps):
implemented with _toposort """
comes_before = dict((a, set()) for a in l)
comes_after = dict((a, set()) for a in l)
comes_after = dict((a, set()) for a in l)
def add_links(a, b): # b depends on a
def add_links(a, b): # b depends on a
comes_after[a].add(b)
comes_after[a].update(comes_after[b])
for c in comes_before[a]:
......@@ -154,7 +162,7 @@ def posort(l, *cmps):
for cmp in cmps:
for a in l:
for b in l:
if cmp(a, b) < 0: # a wants to come before b
if cmp(a, b) < 0: # a wants to come before b
# if this wouldn't cause a cycle and isn't already known
if not b in comes_before[a] and not b in comes_after[a]:
add_links(a, b)
......@@ -162,6 +170,7 @@ def posort(l, *cmps):
return _toposort(comes_after)
def sort_apply_nodes(inputs, outputs, cmps):
""" Order a graph of apply nodes according to a list of comparators
......@@ -184,6 +193,7 @@ def sort_apply_nodes(inputs, outputs, cmps):
return posort(list_of_nodes(inputs, outputs), *cmps)
def sort_schedule_fn(*cmps):
""" Make a schedule function from comparators
......@@ -192,11 +202,13 @@ def sort_schedule_fn(*cmps):
"""
dependence = make_dependence_cmp()
cmps = (dependence,) + cmps
def schedule(fgraph):
""" Order nodes in a FunctionGraph """
return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps)
return schedule
def key_to_cmp(key):
def key_cmp(a, b):
return cmp(key(a), key(b))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论