提交 8f2d3dc5 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

remove global dependence and depends functions

上级 1d97639c
...@@ -10,32 +10,19 @@ def memodict(f): ...@@ -10,32 +10,19 @@ def memodict(f):
return memodict().__getitem__ return memodict().__getitem__
## end of http://code.activestate.com/recipes/578231/ }}} ## end of http://code.activestate.com/recipes/578231/ }}}
@memodict def make_depends():
def depends((a, b)): @memodict
def depends((a, b)):
""" Returns True if a depends on b """ """ Returns True if a depends on b """
return (any(bout in a.inputs for bout in b.outputs) return (any(bout in a.inputs for bout in b.outputs)
or any(depends((ainp.owner, b)) for ainp in a.inputs or any(depends((ainp.owner, b)) for ainp in a.inputs
if ainp.owner)) if ainp.owner))
return depends
def dependence(a, b):
""" A cmp function for nodes in a graph - does a depend on b?
Returns positive number if a depends on b
Returns negative number if b depends on a
Returns 0 otherwise
"""
if depends((a, b)): return 1
if depends((b, a)): return -1
return 0
def make_dependence_cmp(): def make_dependence_cmp():
""" Create a comparator to represent the dependence of nodes in a graph """ """ Create a comparator to represent the dependence of nodes in a graph """
@memodict
def depends((a, b)): depends = make_depends()
""" Returns True if a depends on b """
return (any(bout in a.inputs for bout in b.outputs)
or any(depends((ainp.owner, b)) for ainp in a.inputs
if ainp.owner))
def dependence(a, b): def dependence(a, b):
""" A cmp function for nodes in a graph - does a depend on b? """ A cmp function for nodes in a graph - does a depend on b?
......
from theano.gof.sched import (dependence, sort_apply_nodes, reverse_dict, from theano.gof.sched import (make_dependence_cmp, sort_apply_nodes,
_toposort, posort) reverse_dict, _toposort, posort)
import theano import theano
from theano import tensor from theano import tensor
from theano.gof.graph import io_toposort from theano.gof.graph import io_toposort
def test_dependence(): def test_dependence():
dependence = make_dependence_cmp()
x = tensor.matrix('x') x = tensor.matrix('x')
y = tensor.dot(x*2, x+1) y = tensor.dot(x*2, x+1)
nodes = io_toposort([x], [y]) nodes = io_toposort([x], [y])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论