提交 2c47d726 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

move sorting code to sched.py

上级 1feb6e22
......@@ -1017,125 +1017,3 @@ def list_of_nodes(inputs, outputs):
lambda o: [inp.owner for inp in o.inputs
if inp.owner
and not any(i in inp.owner.outputs for i in inputs)])
## {{{ http://code.activestate.com/recipes/578231/ (r1)
def memodict(f):
    """Memoization decorator for a function taking a single argument.

    Results are stored in a ``dict`` subclass keyed by the argument, so the
    argument must be hashable.  The returned callable is the cache's bound
    ``__getitem__``; it accepts exactly one positional argument.
    """
    class _Cache(dict):
        def __missing__(self, key):
            # dict.__getitem__ calls this only on a cache miss:
            # compute the value, remember it, and hand it back.
            value = f(key)
            self[key] = value
            return value

    cache = _Cache()
    return cache.__getitem__
## end of http://code.activestate.com/recipes/578231/ }}}
@memodict
def depends(pair):
    """Return True if node ``a`` depends on node ``b``, directly or not.

    ``pair`` is the 2-tuple ``(a, b)``.  A single tuple argument is used
    because ``memodict`` caches on exactly one hashable key; call this as
    ``depends((a, b))``.

    ``a`` depends on ``b`` when one of ``b.outputs`` is among ``a.inputs``,
    or when the owner of any of ``a``'s inputs depends on ``b``.
    """
    # Unpack in the body rather than in the signature: tuple parameters
    # (``def depends((a, b))``) are a SyntaxError on Python 3.
    a, b = pair
    return (any(bout in a.inputs for bout in b.outputs)
            or any(depends((ainp.owner, b)) for ainp in a.inputs
                   if ainp.owner))
def dependence(a, b):
    """A cmp-style function for graph nodes - does a depend on b?

    Returns 1 if a depends on b,
    -1 if b depends on a,
    and 0 otherwise.
    """
    if depends((a, b)):
        return 1
    # Only checked when a does not depend on b.
    return -1 if depends((b, a)) else 0
def reverse_dict(d):
    """Reverse the direction of a dependence dict.

    Every value listed under a key becomes a key of the result, mapped to the
    tuple of original keys that listed it.  Keys with no values do not appear
    in the result.

    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> reverse_dict(d)
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
    """
    reversed_map = {}
    for source, targets in d.items():
        for target in targets:
            seen_so_far = reversed_map.get(target, ())
            reversed_map[target] = seen_so_far + (source,)
    return reversed_map
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = set((v for v in edges if v not in incoming_edges))
L = []
while S:
n = S.pop()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S.add(m)
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
def posort(l, *cmps):
    """Partially ordered sort with multiple comparators.

    Each comparator is called as ``cmp(a, b)`` on every pair of items of
    ``l``; a negative result means "a wants to come before b".  Comparators
    are consulted in the order given, and a preference is recorded only when
    no ordering between the two items is known yet, so earlier comparators
    win over later ones.  The collected relation is then ordered with
    ``_toposort``.
    """
    predecessors = dict((item, set()) for item in l)
    successors = dict((item, set()) for item in l)

    def add_links(first, second):  # second depends on first
        # Everything after `second` is now also after `first` ...
        successors[first].add(second)
        successors[first].update(successors[second])
        # ... and after everything that already preceded `first`.
        for earlier in predecessors[first]:
            successors[earlier].update(successors[first])
        # Symmetric bookkeeping for the "comes before" direction.
        predecessors[second].add(first)
        predecessors[second].update(predecessors[first])
        for later in successors[second]:
            predecessors[later].update(predecessors[second])

    def check():
        """ Tests for cycles in manufactured edges """
        for a in l:
            for b in l:
                assert not (b in successors[a] and a in successors[b])

    for compare in cmps:
        for a in l:
            for b in l:
                if compare(a, b) >= 0:
                    continue  # a does not ask to come before b
                # Skip if it would cause a cycle or is already known.
                if b in predecessors[a] or b in successors[a]:
                    continue
                add_links(a, b)
    # check() # debug code

    return _toposort(successors)
def sort_apply_nodes(inputs, outputs, cmps):
    """Order a graph of apply nodes according to a list of comparators.

    The nodes are gathered with ``list_of_nodes(inputs, outputs)`` and passed
    to ``posort`` together with the comparators in ``cmps``.

    The following example sorts first by dependence of nodes (this is a
    topological sort) and then by lexicographical ordering (nodes that start
    with 'E' come before nodes that start with 'I' if there is no dependence).

    >>> from theano.gof.graph import sort_apply_nodes, dependence
    >>> from theano.tensor import matrix, dot
    >>> x = matrix('x')
    >>> y = dot(x*2, x+1)
    >>> str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
    >>> sort_apply_nodes([x], [y], cmps=[dependence, str_cmp])
    [Elemwise{add,no_inplace}(x, InplaceDimShuffle{x,x}.0),
     InplaceDimShuffle{x,x}(TensorConstant{2}),
     Elemwise{mul,no_inplace}(x, InplaceDimShuffle{x,x}.0),
     InplaceDimShuffle{x,x}(TensorConstant{1}),
     dot(Elemwise{mul,no_inplace}.0, Elemwise{add,no_inplace}.0)]
    """
    nodes = list_of_nodes(inputs, outputs)
    return posort(nodes, *cmps)
from graph import list_of_nodes
## {{{ http://code.activestate.com/recipes/578231/ (r1)
def memodict(f):
    """Memoization decorator for a function taking a single argument.

    Results are stored in a ``dict`` subclass keyed by the argument, so the
    argument must be hashable.  The returned callable is the cache's bound
    ``__getitem__``; it accepts exactly one positional argument.
    """
    class _Cache(dict):
        def __missing__(self, key):
            # dict.__getitem__ calls this only on a cache miss:
            # compute the value, remember it, and hand it back.
            value = f(key)
            self[key] = value
            return value

    cache = _Cache()
    return cache.__getitem__
## end of http://code.activestate.com/recipes/578231/ }}}
@memodict
def depends(pair):
    """Return True if node ``a`` depends on node ``b``, directly or not.

    ``pair`` is the 2-tuple ``(a, b)``.  A single tuple argument is used
    because ``memodict`` caches on exactly one hashable key; call this as
    ``depends((a, b))``.

    ``a`` depends on ``b`` when one of ``b.outputs`` is among ``a.inputs``,
    or when the owner of any of ``a``'s inputs depends on ``b``.
    """
    # Unpack in the body rather than in the signature: tuple parameters
    # (``def depends((a, b))``) are a SyntaxError on Python 3.
    a, b = pair
    return (any(bout in a.inputs for bout in b.outputs)
            or any(depends((ainp.owner, b)) for ainp in a.inputs
                   if ainp.owner))
def dependence(a, b):
    """A cmp-style function for graph nodes - does a depend on b?

    Returns 1 if a depends on b,
    -1 if b depends on a,
    and 0 otherwise.
    """
    if depends((a, b)):
        return 1
    # Only checked when a does not depend on b.
    return -1 if depends((b, a)) else 0
def reverse_dict(d):
    """Reverse the direction of a dependence dict.

    Every value listed under a key becomes a key of the result, mapped to the
    tuple of original keys that listed it.  Keys with no values do not appear
    in the result.

    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> reverse_dict(d)
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
    """
    reversed_map = {}
    for source, targets in d.items():
        for target in targets:
            seen_so_far = reversed_map.get(target, ())
            reversed_map[target] = seen_so_far + (source,)
    return reversed_map
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = set((v for v in edges if v not in incoming_edges))
L = []
while S:
n = S.pop()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S.add(m)
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
def posort(l, *cmps):
    """Partially ordered sort with multiple comparators.

    Each comparator is called as ``cmp(a, b)`` on every pair of items of
    ``l``; a negative result means "a wants to come before b".  Comparators
    are consulted in the order given, and a preference is recorded only when
    no ordering between the two items is known yet, so earlier comparators
    win over later ones.  The collected relation is then ordered with
    ``_toposort``.
    """
    predecessors = dict((item, set()) for item in l)
    successors = dict((item, set()) for item in l)

    def add_links(first, second):  # second depends on first
        # Everything after `second` is now also after `first` ...
        successors[first].add(second)
        successors[first].update(successors[second])
        # ... and after everything that already preceded `first`.
        for earlier in predecessors[first]:
            successors[earlier].update(successors[first])
        # Symmetric bookkeeping for the "comes before" direction.
        predecessors[second].add(first)
        predecessors[second].update(predecessors[first])
        for later in successors[second]:
            predecessors[later].update(predecessors[second])

    def check():
        """ Tests for cycles in manufactured edges """
        for a in l:
            for b in l:
                assert not (b in successors[a] and a in successors[b])

    for compare in cmps:
        for a in l:
            for b in l:
                if compare(a, b) >= 0:
                    continue  # a does not ask to come before b
                # Skip if it would cause a cycle or is already known.
                if b in predecessors[a] or b in successors[a]:
                    continue
                add_links(a, b)
    # check() # debug code

    return _toposort(successors)
def sort_apply_nodes(inputs, outputs, cmps):
    """Order a graph of apply nodes according to a list of comparators.

    The nodes are gathered with ``list_of_nodes(inputs, outputs)`` and passed
    to ``posort`` together with the comparators in ``cmps``.

    The following example sorts first by dependence of nodes (this is a
    topological sort) and then by lexicographical ordering (nodes that start
    with 'E' come before nodes that start with 'I' if there is no dependence).

    >>> from theano.gof.sched import sort_apply_nodes, dependence
    >>> from theano.tensor import matrix, dot
    >>> x = matrix('x')
    >>> y = dot(x*2, x+1)
    >>> str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
    >>> sort_apply_nodes([x], [y], cmps=[dependence, str_cmp])
    [Elemwise{add,no_inplace}(x, InplaceDimShuffle{x,x}.0),
     InplaceDimShuffle{x,x}(TensorConstant{2}),
     Elemwise{mul,no_inplace}(x, InplaceDimShuffle{x,x}.0),
     InplaceDimShuffle{x,x}(TensorConstant{1}),
     dot(Elemwise{mul,no_inplace}.0, Elemwise{add,no_inplace}.0)]
    """
    nodes = list_of_nodes(inputs, outputs)
    return posort(nodes, *cmps)
......@@ -3,8 +3,7 @@ import unittest
from theano import tensor
from theano.gof.graph import (
Apply, as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable, dependence, sort_apply_nodes, posort,
reverse_dict, _toposort)
is_same_graph, Variable)
from theano.gof.op import Op
from theano.gof.type import Type
......@@ -291,41 +290,3 @@ class TestIsSameGraph(unittest.TestCase):
({y: x, t: z}, True))),
],
debug=False)
def test_dependence():
    """Adjacent nodes of a topological order never report "a depends on b"."""
    x = tensor.matrix('x')
    y = tensor.dot(x * 2, x + 1)
    ordered = io_toposort([x], [y])
    # Walk the order pairwise: (n0, n1), (n1, n2), ...
    for earlier, later in zip(ordered, ordered[1:]):
        assert dependence(earlier, later) <= 0
def test_sort_apply_nodes():
    """Sorting with a purely lexicographical comparator orders nodes by str()."""
    x = tensor.matrix('x')
    y = tensor.dot(x*2, x+1)

    def str_cmp(a, b):
        # Three-way lexicographical comparison (-1, 0 or 1), spelled out
        # because the ``cmp`` builtin does not exist on Python 3.
        sa, sb = str(a), str(b)
        return (sa > sb) - (sa < sb)

    nodes = sort_apply_nodes([x], [y], cmps=[str_cmp])
    for a, b in zip(nodes[:-1], nodes[1:]):
        assert str(a) <= str(b)
def test_reverse_dict():
    """reverse_dict maps each value back to the tuple of keys that listed it."""
    forward = {'a': (1, 2), 'b': (2, 3), 'c':()}
    expected = {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
    assert reverse_dict(forward) == expected
def test__toposort():
    """No node may be a successor of itself or of any node placed after it."""
    edges = {1: {4, 6, 7}, 2: {4, 6, 7}, 3: {5, 7}, 4: {6, 7}, 5: {7}}
    order = _toposort(edges)
    violations = [(a, b)
                  for i, a in enumerate(order)
                  for b in order[i:]
                  if a in edges.get(b, ())]
    assert not violations
def test_posort_easy():
    """With one total-order comparator posort behaves like a plain sort."""
    nodes = "asdfghjkl"

    def three_way(a, b):
        if a < b:
            return -1
        if a > b:
            return 1
        return 0

    assert posort(nodes, three_way) == list("adfghjkls")
def test_posort():
    """posort with three comparators applied in priority order.

    Items are ordered by last digit first, then by the parity of the tens
    digit, and finally by value.
    """
    # list(...) so the input is a real list on Python 3 as well.
    l = list(range(1, 20))
    # ``//`` keeps the tens-digit computation an integer division on
    # Python 3; on Python 2 it is identical to ``/`` for ints.
    cmps = [lambda a, b: a % 10 - b % 10,
            lambda a, b: (a // 10) % 2 - (b // 10) % 2,
            lambda a, b: a - b]
    assert posort(l, *cmps) == \
        [10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19]
from theano.gof.sched import (dependence, sort_apply_nodes, reverse_dict,
_toposort, posort)
import theano
from theano import tensor
from theano.gof.graph import io_toposort
def test_dependence():
    """Adjacent nodes of a topological order never report "a depends on b"."""
    x = tensor.matrix('x')
    y = tensor.dot(x * 2, x + 1)
    ordered = io_toposort([x], [y])
    # Walk the order pairwise: (n0, n1), (n1, n2), ...
    for earlier, later in zip(ordered, ordered[1:]):
        assert dependence(earlier, later) <= 0
def test_sort_apply_nodes():
    """Sorting with a purely lexicographical comparator orders nodes by str()."""
    x = tensor.matrix('x')
    y = tensor.dot(x*2, x+1)

    def str_cmp(a, b):
        # Three-way lexicographical comparison (-1, 0 or 1), spelled out
        # because the ``cmp`` builtin does not exist on Python 3.
        sa, sb = str(a), str(b)
        return (sa > sb) - (sa < sb)

    nodes = sort_apply_nodes([x], [y], cmps=[str_cmp])
    for a, b in zip(nodes[:-1], nodes[1:]):
        assert str(a) <= str(b)
def test_reverse_dict():
    """reverse_dict maps each value back to the tuple of keys that listed it."""
    forward = {'a': (1, 2), 'b': (2, 3), 'c':()}
    expected = {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
    assert reverse_dict(forward) == expected
def test__toposort():
    """No node may be a successor of itself or of any node placed after it."""
    edges = {1: {4, 6, 7}, 2: {4, 6, 7}, 3: {5, 7}, 4: {6, 7}, 5: {7}}
    order = _toposort(edges)
    violations = [(a, b)
                  for i, a in enumerate(order)
                  for b in order[i:]
                  if a in edges.get(b, ())]
    assert not violations
def test_posort_easy():
    """With one total-order comparator posort behaves like a plain sort."""
    nodes = "asdfghjkl"

    def three_way(a, b):
        if a < b:
            return -1
        if a > b:
            return 1
        return 0

    assert posort(nodes, three_way) == list("adfghjkls")
def test_posort():
    """posort with three comparators applied in priority order.

    Items are ordered by last digit first, then by the parity of the tens
    digit, and finally by value.
    """
    # list(...) so the input is a real list on Python 3 as well.
    l = list(range(1, 20))
    # ``//`` keeps the tens-digit computation an integer division on
    # Python 3; on Python 2 it is identical to ``/`` for ints.
    cmps = [lambda a, b: a % 10 - b % 10,
            lambda a, b: (a // 10) % 2 - (b // 10) % 2,
            lambda a, b: a - b]
    assert posort(l, *cmps) == \
        [10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论