提交 37e5f09c authored 作者: Matthew Rocklin's avatar Matthew Rocklin

add sort_schedule_fn and tests

上级 ec08d469
......@@ -155,3 +155,15 @@ def sort_apply_nodes(inputs, outputs, cmps):
"""
return posort(list_of_nodes(inputs, outputs), *cmps)
def sort_schedule_fn(*cmps):
""" Make a schedule function from comparators
See also:
sort_apply_nodes
"""
cmps = (dependence,) + cmps
def schedule(fgraph):
""" Order nodes in a FunctionGraph """
return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps)
return schedule
......@@ -163,3 +163,18 @@ class TestWrapLinker(unittest.TestCase):
fn()
assert nodes == [div, add, mul]
assert o[0].data == 1.5
def test_sort_schedule_fn():
import theano
from theano.gof.sched import sort_schedule_fn, depends
x = theano.tensor.matrix('x')
y = theano.tensor.dot(x[:5]*2, x.T+1).T
str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
linker = theano.OpWiseCLinker(schedule=sort_schedule_fn(str_cmp))
mode = theano.Mode(linker=linker)
f = theano.function((x,), (y,), mode=mode)
nodes = f.maker.linker.make_all()[-1]
for a, b in zip(nodes[:-1], nodes[1:]):
if not depends((b,a)):
assert str(a) < str(b)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论