提交 25ef10c8 authored 作者: Frederic's avatar Frederic

pep8

上级 0e1102f1
......@@ -114,7 +114,8 @@ class TestPerformLinker(unittest.TestCase):
x, y, z = inputs()
a, d = add(x, y), div(x, y)
e = mul(a, d)
fn = perform_linker(FunctionGraph(*graph.clone([x, y, a], [e]))).make_function()
fn = perform_linker(FunctionGraph(*graph.clone([x, y, a],
[e]))).make_function()
assert fn(1.0, 2.0, 9.0) == 4.5
def test_skiphole(self):
......@@ -122,7 +123,8 @@ class TestPerformLinker(unittest.TestCase):
a = add(x, y)
r = raise_err(a)
e = add(r, a)
fn = perform_linker(FunctionGraph(*graph.clone([x, y, r], [e]))).make_function()
fn = perform_linker(FunctionGraph(*graph.clone([x, y, r],
[e]))).make_function()
assert fn(1.0, 2.0, 4.5) == 7.5
......@@ -141,8 +143,8 @@ class TestWrapLinker(unittest.TestCase):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(
FunctionGraph([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
FunctionGraph([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
......@@ -159,20 +161,21 @@ class TestWrapLinker(unittest.TestCase):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(
FunctionGraph([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
FunctionGraph([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
assert nodes == [div, add, mul]
assert o[0].data == 1.5
def test_sort_schedule_fn():
import theano
from theano.gof.sched import sort_schedule_fn, make_depends
x = theano.tensor.matrix('x')
y = theano.tensor.dot(x[:5]*2, x.T+1).T
str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
linker = theano.OpWiseCLinker(schedule=sort_schedule_fn(str_cmp))
mode = theano.Mode(linker=linker)
f = theano.function((x,), (y,), mode=mode)
......@@ -180,7 +183,7 @@ def test_sort_schedule_fn():
nodes = f.maker.linker.make_all()[-1]
depends = make_depends()
for a, b in zip(nodes[:-1], nodes[1:]):
if not depends((b,a)):
if not depends((b, a)):
assert str(a) < str(b)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论