提交 253add96 authored 作者: James Bergstra's avatar James Bergstra

pep8 link tests

上级 c0fd55c9
import unittest
from theano.gof import graph
from theano.gof.graph import Variable, Apply, Constant
......@@ -8,26 +9,25 @@ from theano.gof import toolbox
from theano.gof.link import *
#from _test_variable import Double
def as_variable(x):
assert isinstance(x, Variable)
return x
class TDouble(Type):
def filter(self, data):
return float(data)
tdouble = TDouble()
def double(name):
return Variable(tdouble, None, None, name = name)
return Variable(tdouble, None, None, name=name)
class MyOp(Op):
def __init__(self, nin, name, impl = None):
def __init__(self, nin, name, impl=None):
self.nin = nin
self.name = name
if impl:
......@@ -54,11 +54,12 @@ sub = MyOp(2, 'Sub', lambda x, y: x - y)
mul = MyOp(2, 'Mul', lambda x, y: x * y)
div = MyOp(2, 'Div', lambda x, y: x / y)
def notimpl(self, x):
raise NotImplementedError()
raise_err = MyOp(1, 'RaiseErr', notimpl)
raise_err = MyOp(1, 'RaiseErr', notimpl)
def inputs():
......@@ -67,17 +68,18 @@ def inputs():
z = double('z')
return x, y, z
def perform_linker(env):
lnk = PerformLinker().accept(env)
return lnk
def Env(inputs, outputs):
e = env.Env(inputs, outputs)
return e
class TestPerformLinker:
class TestPerformLinker(unittest.TestCase):
def test_thunk(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
......@@ -107,34 +109,37 @@ class TestPerformLinker:
def test_input_dependency0(self):
x, y, z = inputs()
a,d = add(x,y), div(x,y)
e = mul(a,d)
a, d = add(x, y), div(x, y)
e = mul(a, d)
fn = perform_linker(Env(*graph.clone([x, y, a], [e]))).make_function()
assert fn(1.0,2.0,9.0) == 4.5
assert fn(1.0, 2.0, 9.0) == 4.5
def test_skiphole(self):
x,y,z = inputs()
a = add(x,y)
x, y, z = inputs()
a = add(x, y)
r = raise_err(a)
e = add(r,a)
fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function()
assert fn(1.0,2.0,4.5) == 7.5
e = add(r, a)
fn = perform_linker(Env(*graph.clone([x, y, r], [e]))).make_function()
assert fn(1.0, 2.0, 4.5) == 7.5
def wrap_linker(env, linkers, wrapper):
lnk = WrapLinker(linkers, wrapper).accept(env)
return lnk
class TestWrapLinker:
class TestWrapLinker(unittest.TestCase):
def test_0(self):
nodes = []
def wrap(i, node, th):
nodes.append(node.op)
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker(allow_gc=False)], wrap).make_thunk()
fn, i, o = wrap_linker(
Env([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
......@@ -143,20 +148,18 @@ class TestWrapLinker:
def test_1(self):
nodes = []
def wrap(i, node, th):
nodes.append(node.op)
th()
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker(allow_gc=False)], wrap).make_thunk()
fn, i, o = wrap_linker(
Env([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1
i[1].data = 2
fn()
assert nodes == [div, add, mul]
assert o[0].data == 1.5
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论