提交 253add96 authored 作者: James Bergstra's avatar James Bergstra

pep8 link tests

上级 c0fd55c9
import unittest
from theano.gof import graph from theano.gof import graph
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
...@@ -8,26 +9,25 @@ from theano.gof import toolbox ...@@ -8,26 +9,25 @@ from theano.gof import toolbox
from theano.gof.link import * from theano.gof.link import *
#from _test_variable import Double
def as_variable(x): def as_variable(x):
assert isinstance(x, Variable) assert isinstance(x, Variable)
return x return x
class TDouble(Type): class TDouble(Type):
def filter(self, data): def filter(self, data):
return float(data) return float(data)
tdouble = TDouble() tdouble = TDouble()
def double(name): def double(name):
return Variable(tdouble, None, None, name = name) return Variable(tdouble, None, None, name=name)
class MyOp(Op): class MyOp(Op):
def __init__(self, nin, name, impl=None):
def __init__(self, nin, name, impl = None):
self.nin = nin self.nin = nin
self.name = name self.name = name
if impl: if impl:
...@@ -54,11 +54,12 @@ sub = MyOp(2, 'Sub', lambda x, y: x - y) ...@@ -54,11 +54,12 @@ sub = MyOp(2, 'Sub', lambda x, y: x - y)
mul = MyOp(2, 'Mul', lambda x, y: x * y) mul = MyOp(2, 'Mul', lambda x, y: x * y)
div = MyOp(2, 'Div', lambda x, y: x / y) div = MyOp(2, 'Div', lambda x, y: x / y)
def notimpl(self, x): def notimpl(self, x):
raise NotImplementedError() raise NotImplementedError()
raise_err = MyOp(1, 'RaiseErr', notimpl)
raise_err = MyOp(1, 'RaiseErr', notimpl)
def inputs(): def inputs():
...@@ -67,17 +68,18 @@ def inputs(): ...@@ -67,17 +68,18 @@ def inputs():
z = double('z') z = double('z')
return x, y, z return x, y, z
def perform_linker(env): def perform_linker(env):
lnk = PerformLinker().accept(env) lnk = PerformLinker().accept(env)
return lnk return lnk
def Env(inputs, outputs): def Env(inputs, outputs):
e = env.Env(inputs, outputs) e = env.Env(inputs, outputs)
return e return e
class TestPerformLinker: class TestPerformLinker(unittest.TestCase):
def test_thunk(self): def test_thunk(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
...@@ -107,34 +109,37 @@ class TestPerformLinker: ...@@ -107,34 +109,37 @@ class TestPerformLinker:
def test_input_dependency0(self): def test_input_dependency0(self):
x, y, z = inputs() x, y, z = inputs()
a,d = add(x,y), div(x,y) a, d = add(x, y), div(x, y)
e = mul(a,d) e = mul(a, d)
fn = perform_linker(Env(*graph.clone([x, y, a], [e]))).make_function() fn = perform_linker(Env(*graph.clone([x, y, a], [e]))).make_function()
assert fn(1.0,2.0,9.0) == 4.5 assert fn(1.0, 2.0, 9.0) == 4.5
def test_skiphole(self): def test_skiphole(self):
x,y,z = inputs() x, y, z = inputs()
a = add(x,y) a = add(x, y)
r = raise_err(a) r = raise_err(a)
e = add(r,a) e = add(r, a)
fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function() fn = perform_linker(Env(*graph.clone([x, y, r], [e]))).make_function()
assert fn(1.0,2.0,4.5) == 7.5 assert fn(1.0, 2.0, 4.5) == 7.5
def wrap_linker(env, linkers, wrapper): def wrap_linker(env, linkers, wrapper):
lnk = WrapLinker(linkers, wrapper).accept(env) lnk = WrapLinker(linkers, wrapper).accept(env)
return lnk return lnk
class TestWrapLinker:
class TestWrapLinker(unittest.TestCase):
def test_0(self): def test_0(self):
nodes = [] nodes = []
def wrap(i, node, th): def wrap(i, node, th):
nodes.append(node.op) nodes.append(node.op)
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker(allow_gc=False)], wrap).make_thunk() fn, i, o = wrap_linker(
Env([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1 i[0].data = 1
i[1].data = 2 i[1].data = 2
fn() fn()
...@@ -143,20 +148,18 @@ class TestWrapLinker: ...@@ -143,20 +148,18 @@ class TestWrapLinker:
def test_1(self): def test_1(self):
nodes = [] nodes = []
def wrap(i, node, th): def wrap(i, node, th):
nodes.append(node.op) nodes.append(node.op)
th() th()
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker(allow_gc=False)], wrap).make_thunk() fn, i, o = wrap_linker(
Env([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1 i[0].data = 1
i[1].data = 2 i[1].data = 2
fn() fn()
assert nodes == [div, add, mul] assert nodes == [div, add, mul]
assert o[0].data == 1.5 assert o[0].data == 1.5
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论