提交 17429f38 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/gof/sched.py

上级 b70504c9
...@@ -26,7 +26,10 @@ from theano.compat import cmp ...@@ -26,7 +26,10 @@ from theano.compat import cmp
def memodict(f): def memodict(f):
""" Memoization decorator for a function taking a single argument """ """
Memoization decorator for a function taking a single argument.
"""
class memodict(defaultdict): class memodict(defaultdict):
def __missing__(self, key): def __missing__(self, key):
ret = self[key] = f(key) ret = self[key] = f(key)
...@@ -39,7 +42,10 @@ def memodict(f): ...@@ -39,7 +42,10 @@ def memodict(f):
def make_depends(): def make_depends():
@memodict @memodict
def depends(pair): def depends(pair):
""" Returns True if a depends on b """ """
Returns True if a depends on b.
"""
a, b = pair a, b = pair
return (any(bout in a.inputs for bout in b.outputs) or return (any(bout in a.inputs for bout in b.outputs) or
any(depends((ainp.owner, b)) for ainp in a.inputs any(depends((ainp.owner, b)) for ainp in a.inputs
...@@ -48,16 +54,22 @@ def make_depends(): ...@@ -48,16 +54,22 @@ def make_depends():
def make_dependence_cmp(): def make_dependence_cmp():
""" Create a comparator to represent the dependence of nodes in a graph """ """
Create a comparator to represent the dependence of nodes in a graph.
"""
depends = make_depends() depends = make_depends()
def dependence(a, b): def dependence(a, b):
""" A cmp function for nodes in a graph - does a depend on b? """
A cmp function for nodes in a graph - does a depend on b?
Returns
-------
int
Positive number if a depends on b, negative number
if b depends on a, 0 otherwise.
Returns positive number if a depends on b
Returns negative number if b depends on a
Returns 0 otherwise
""" """
if depends((a, b)): if depends((a, b)):
return 1 return 1
...@@ -69,17 +81,22 @@ def make_dependence_cmp(): ...@@ -69,17 +81,22 @@ def make_dependence_cmp():
def reverse_dict(d): def reverse_dict(d):
"""Reverses direction of dependence dict """
Reverses direction of dependence dict.
Notes
-----
dict order is not deterministic. As we iterate on the
input dict, it makes the output of this function depend on the
dict order. So this function output order should be considered
as undeterministic.
Examples
--------
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d) >>> reverse_dict(d)
{1: ('a',), 2: ('a', 'b'), 3: ('b',)} {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order are not deterministic. As we iterate on the
input dict, it make the output of this function depend on the
dict order. So this function output order should be considered
as undeterministic.
""" """
result = {} result = {}
for key in d: for key in d:
...@@ -89,21 +106,32 @@ def reverse_dict(d): ...@@ -89,21 +106,32 @@ def reverse_dict(d):
def _toposort(edges): def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices) """
Topological sort algorithm by Kahn [1] - O(nodes + vertices).
inputs: Parameters
edges - a dict of the form {a: {b, c}} where b and c depend on a ----------
outputs: edges
L - an ordered list of nodes that satisfy the dependencies of edges A dict of the form {a: {b, c}} where b and c depend on a.
>>> _toposort({1: {2, 3}, 2: (3, )}) Returns
[1, 2, 3] -------
L : list
An ordered list of nodes that satisfy the dependencies of edges.
Closely follows the wikipedia page [2] Closely follows the wikipedia page [2]
References
----------
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks", [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
Examples
--------
>>> _toposort({1: {2, 3}, 2: (3, )})
[1, 2, 3]
""" """
incoming_edges = reverse_dict(edges) incoming_edges = reverse_dict(edges)
incoming_edges = dict((k, set(val)) incoming_edges = dict((k, set(val))
...@@ -125,25 +153,38 @@ def _toposort(edges): ...@@ -125,25 +153,38 @@ def _toposort(edges):
def posort(l, *cmps): def posort(l, *cmps):
""" Partially ordered sort with multiple comparators """
Partially ordered sort with multiple comparators.
Given a list of comparators order the elements in l so that the comparators
are satisfied as much as possible giving precedence to earlier comparators. Given a list of comparators, orders the elements in l so that the
comparators are satisfied as much as possible giving precedence to
inputs: earlier comparators.
l - an iterable of nodes in a graph
cmps - a sequence of comparator functions that describe which nodes Parameters
should come before which others ----------
l
outputs: An iterable of nodes in a graph.
a list of nodes which satisfy the comparators as much as possible. cmps
A sequence of comparator functions that describe which nodes should
come before which others.
Returns
-------
list
A list of nodes which satisfy the comparators as much as possible.
Notes
-----
Implemented with _toposort.
Examples
--------
>>> lower_tens = lambda a, b: a/10 - b/10 # prefer lower numbers div 10 >>> lower_tens = lambda a, b: a/10 - b/10 # prefer lower numbers div 10
>>> prefer evens = lambda a, b: a%2 - b%2 # prefer even numbers >>> prefer evens = lambda a, b: a%2 - b%2 # prefer even numbers
>>> posort(list(range(20)), lower_tens, prefer_evens) >>> posort(list(range(20)), lower_tens, prefer_evens)
[0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15] [0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15]
implemented with _toposort """ """
comes_before = dict((a, set()) for a in l) comes_before = dict((a, set()) for a in l)
comes_after = dict((a, set()) for a in l) comes_after = dict((a, set()) for a in l)
...@@ -158,7 +199,10 @@ def posort(l, *cmps): ...@@ -158,7 +199,10 @@ def posort(l, *cmps):
comes_before[c].update(comes_before[b]) comes_before[c].update(comes_before[b])
def check(): def check():
""" Tests for cycles in manufactured edges """ """
Tests for cycles in manufactured edges.
"""
for a in l: for a in l:
for b in l: for b in l:
assert not(b in comes_after[a] and a in comes_after[b]) assert not(b in comes_after[a] and a in comes_after[b])
...@@ -176,12 +220,15 @@ def posort(l, *cmps): ...@@ -176,12 +220,15 @@ def posort(l, *cmps):
def sort_apply_nodes(inputs, outputs, cmps): def sort_apply_nodes(inputs, outputs, cmps):
""" Order a graph of apply nodes according to a list of comparators """
Order a graph of apply nodes according to a list of comparators.
The following example sorts first by dependence of nodes (this is a The following example sorts first by dependence of nodes (this is a
topological sort) and then by lexicographical ordering (nodes that start topological sort) and then by lexicographical ordering (nodes that start
with 'E' come before nodes that start with 'I' if there is no dependence. with 'E' come before nodes that start with 'I' if there is no dependence.
Examples
--------
>>> from theano.gof.graph import sort_apply_nodes, dependence >>> from theano.gof.graph import sort_apply_nodes, dependence
>>> from theano.tensor import matrix, dot >>> from theano.tensor import matrix, dot
>>> x = matrix('x') >>> x = matrix('x')
...@@ -193,22 +240,28 @@ def sort_apply_nodes(inputs, outputs, cmps): ...@@ -193,22 +240,28 @@ def sort_apply_nodes(inputs, outputs, cmps):
Elemwise{mul,no_inplace}(x, InplaceDimShuffle{x,x}.0), Elemwise{mul,no_inplace}(x, InplaceDimShuffle{x,x}.0),
InplaceDimShuffle{x,x}(TensorConstant{1}), InplaceDimShuffle{x,x}(TensorConstant{1}),
dot(Elemwise{mul,no_inplace}.0, Elemwise{add,no_inplace}.0)] dot(Elemwise{mul,no_inplace}.0, Elemwise{add,no_inplace}.0)]
"""
"""
return posort(list_of_nodes(inputs, outputs), *cmps) return posort(list_of_nodes(inputs, outputs), *cmps)
def sort_schedule_fn(*cmps): def sort_schedule_fn(*cmps):
""" Make a schedule function from comparators """
Make a schedule function from comparators.
See Also
--------
sort_apply_nodes
See also:
sort_apply_nodes
""" """
dependence = make_dependence_cmp() dependence = make_dependence_cmp()
cmps = (dependence,) + cmps cmps = (dependence,) + cmps
def schedule(fgraph): def schedule(fgraph):
""" Order nodes in a FunctionGraph """ """
Order nodes in a FunctionGraph.
"""
return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps) return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps)
return schedule return schedule
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论