提交 632341d4 authored 作者: nouiz's avatar nouiz

Merge pull request #914 from mrocklin/fgraph-order

Add sort_apply_nodes function to gof/sched.py
...@@ -553,7 +553,9 @@ class FunctionGraph(utils.object2): ...@@ -553,7 +553,9 @@ class FunctionGraph(utils.object2):
# 1-element graphs. # 1-element graphs.
return list(self.apply_nodes) return list(self.apply_nodes)
fg = self fg = self
ords = self.orderings() ords = self.orderings()
order = graph.io_toposort(fg.inputs, fg.outputs, ords) order = graph.io_toposort(fg.inputs, fg.outputs, ords)
return order return order
......
...@@ -1029,3 +1029,11 @@ def view_roots(r): ...@@ -1029,3 +1029,11 @@ def view_roots(r):
return [r] return [r]
else: else:
return [r] return [r]
def list_of_nodes(inputs, outputs):
""" Return the apply nodes of the graph between inputs and outputs """
return stack_search(
deque([o.owner for o in outputs]),
lambda o: [inp.owner for inp in o.inputs
if inp.owner
and not any(i in inp.owner.outputs for i in inputs)])
from graph import list_of_nodes
## {{{ http://code.activestate.com/recipes/578231/ (r1)
def memodict(f):
""" Memoization decorator for a function taking a single argument """
class memodict(dict):
def __missing__(self, key):
ret = self[key] = f(key)
return ret
return memodict().__getitem__
## end of http://code.activestate.com/recipes/578231/ }}}
def make_depends():
@memodict
def depends((a, b)):
""" Returns True if a depends on b """
return (any(bout in a.inputs for bout in b.outputs)
or any(depends((ainp.owner, b)) for ainp in a.inputs
if ainp.owner))
return depends
def make_dependence_cmp():
""" Create a comparator to represent the dependence of nodes in a graph """
depends = make_depends()
def dependence(a, b):
""" A cmp function for nodes in a graph - does a depend on b?
Returns positive number if a depends on b
Returns negative number if b depends on a
Returns 0 otherwise
"""
if depends((a, b)): return 1
if depends((b, a)): return -1
return 0
return dependence
def reverse_dict(d):
""" Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d)
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
"""
result = {}
for key in d:
for val in d[key]:
result[val] = result.get(val, tuple()) + (key, )
return result
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> _toposort({1: {2, 3}, 2: (3, )})
[1, 2, 3]
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = dict((k, set(val)) for k, val in incoming_edges.items())
S = set((v for v in edges if v not in incoming_edges))
L = []
while S:
n = S.pop()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S.add(m)
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
def posort(l, *cmps):
""" Partially ordered sort with multiple comparators
Given a list of comparators order the elements in l so that the comparators
are satisfied as much as possible giving precedence to earlier comparators.
inputs:
l - an iterable of nodes in a graph
cmps - a sequence of comparator functions that describe which nodes
should come before which others
outputs:
a list of nodes which satisfy the comparators as much as possible.
>>> lower_tens = lambda a, b: a/10 - b/10 # prefer lower numbers div 10
>>> prefer evens = lambda a, b: a%2 - b%2 # prefer even numbers
>>> posort(range(20), lower_tens, prefer_evens)
[0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15]
implemented with _toposort """
comes_before = dict((a, set()) for a in l)
comes_after = dict((a, set()) for a in l)
def add_links(a, b): # b depends on a
comes_after[a].add(b)
comes_after[a].update(comes_after[b])
for c in comes_before[a]:
comes_after[c].update(comes_after[a])
comes_before[b].add(a)
comes_before[b].update(comes_before[a])
for c in comes_after[b]:
comes_before[c].update(comes_before[b])
def check():
""" Tests for cycles in manufactured edges """
for a in l:
for b in l:
assert not(b in comes_after[a] and a in comes_after[b])
for cmp in cmps:
for a in l:
for b in l:
if cmp(a, b) < 0: # a wants to come before b
# if this wouldn't cause a cycle and isn't already known
if not b in comes_before[a] and not b in comes_after[a]:
add_links(a, b)
# check() # debug code
return _toposort(comes_after)
def sort_apply_nodes(inputs, outputs, cmps):
""" Order a graph of apply nodes according to a list of comparators
The following example sorts first by dependence of nodes (this is a
topological sort) and then by lexicographical ordering (nodes that start
with 'E' come before nodes that start with 'I' if there is no dependence.
>>> from theano.gof.graph import sort_apply_nodes, dependence
>>> from theano.tensor import matrix, dot
>>> x = matrix('x')
>>> y = dot(x*2, x+1)
>>> str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
>>> sort_apply_nodes([x], [y], cmps=[dependence, str_cmp])
[Elemwise{add,no_inplace}(x, InplaceDimShuffle{x,x}.0),
InplaceDimShuffle{x,x}(TensorConstant{2}),
Elemwise{mul,no_inplace}(x, InplaceDimShuffle{x,x}.0),
InplaceDimShuffle{x,x}(TensorConstant{1}),
dot(Elemwise{mul,no_inplace}.0, Elemwise{add,no_inplace}.0)]
"""
return posort(list_of_nodes(inputs, outputs), *cmps)
from theano.gof.sched import (make_dependence_cmp, sort_apply_nodes,
reverse_dict, _toposort, posort)
import theano
from theano import tensor
from theano.gof.graph import io_toposort
def test_dependence():
dependence = make_dependence_cmp()
x = tensor.matrix('x')
y = tensor.dot(x*2, x+1)
nodes = io_toposort([x], [y])
for a, b in zip(nodes[:-1], nodes[1:]):
assert dependence(a, b) <= 0
def test_sort_apply_nodes():
x = tensor.matrix('x')
y = tensor.dot(x*2, x+1)
str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
nodes = sort_apply_nodes([x], [y], cmps=[str_cmp])
for a, b in zip(nodes[:-1], nodes[1:]):
assert str(a) <= str(b)
def test_reverse_dict():
d = {'a': (1, 2), 'b': (2, 3), 'c':()}
assert reverse_dict(d) == {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
def test__toposort():
edges = {1: set((4, 6, 7)), 2: set((4, 6, 7)),
3: set((5, 7)), 4: set((6, 7)), 5: set((7,))}
order = _toposort(edges)
assert not any(a in edges.get(b, ()) for i, a in enumerate(order)
for b in order[i:])
def test_posort_easy():
nodes = "asdfghjkl"
cmp = lambda a,b: -1 if a<b else 1 if a>b else 0
assert posort(nodes, cmp) == list("adfghjkls")
def test_posort():
l = range(1,20)
cmps = [lambda a,b: a%10 - b%10, lambda a, b: (a/10)%2 - (b/10)%2,
lambda a,b: a-b]
assert posort(l, *cmps) == \
[10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论