提交 f493e77a authored 作者: Frederic's avatar Frederic

pep8

上级 c1366d70
......@@ -9,10 +9,12 @@ from theano.gof.op import Op
from theano.gof import env
from theano.gof import toolbox
def as_variable(x):
assert isinstance(x, Variable)
return x
class TDouble(Type):
def filter(self, data):
return float(data)
......@@ -62,8 +64,9 @@ class TDouble(Type):
tdouble = TDouble()
def double(name):
return Variable(tdouble, None, None, name = name)
return Variable(tdouble, None, None, name=name)
class MyOp(Op):
......@@ -87,6 +90,7 @@ class MyOp(Op):
def perform(self, node, inputs, out_):
out, = out_
out[0] = self.impl(*inputs)
def c_code_cache_version(self):
return ()
......@@ -95,6 +99,7 @@ class Unary(MyOp):
def __init__(self):
MyOp.__init__(self, 1, self.__class__.__name__)
class Binary(MyOp):
def __init__(self):
MyOp.__init__(self, 2, self.__class__.__name__)
......@@ -105,37 +110,45 @@ class Add(Binary):
x, y = inp
z, = out
return "%(z)s = %(x)s + %(y)s;" % locals()
def impl(self, x, y):
return x + y
add = Add()
class Sub(Binary):
def c_code(self, node, name, inp, out, sub):
x, y = inp
z, = out
return "%(z)s = %(x)s - %(y)s;" % locals()
def impl(self, x, y):
return -10 # erroneous (most of the time)
return -10 # erroneous (most of the time)
sub = Sub()
class Mul(Binary):
def c_code(self, node, name, inp, out, sub):
x, y = inp
z, = out
return "%(z)s = %(x)s * %(y)s;" % locals()
def impl(self, x, y):
return x * y
mul = Mul()
class Div(Binary):
def c_code(self, node, name, inp, out, sub):
x, y = inp
z, = out
return "%(z)s = %(x)s / %(y)s;" % locals()
def impl(self, x, y):
return x / y
div = Div()
def inputs():
x = double('x')
y = double('y')
......@@ -159,6 +172,7 @@ def test_clinker_straightforward():
fn = lnk.make_function()
assert fn(2.0, 2.0, 2.0) == 2.0
def test_clinker_literal_inlining():
x, y, z = inputs()
z = Constant(tdouble, 4.12345678)
......@@ -169,7 +183,8 @@ def test_clinker_literal_inlining():
code = lnk.code_gen()
print "=== Code generated ==="
print code
assert "4.12345678" in code # we expect the number to be inlined
assert "4.12345678" in code # we expect the number to be inlined
def test_clinker_single_node():
x, y, z = inputs()
......@@ -178,6 +193,7 @@ def test_clinker_single_node():
fn = lnk.make_function()
assert fn(2.0, 7.0) == 9
def test_clinker_dups():
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
......@@ -187,6 +203,7 @@ def test_clinker_dups():
assert fn(2.0, 2.0) == 4
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_dups_inner():
# Testing that duplicates are allowed inside the graph
x, y, z = inputs()
......@@ -196,7 +213,6 @@ def test_clinker_dups_inner():
assert fn(1.0, 2.0, 3.0) == 8.0
######################
# Test OpWiseCLinker #
######################
......@@ -208,9 +224,10 @@ def test_opwiseclinker_straightforward():
fn = lnk.make_function()
assert fn(2.0, 2.0, 2.0) == 2.0
def test_opwiseclinker_constant():
x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x')
x = Constant(tdouble, 7.2, name='x')
e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function()
......@@ -218,13 +235,14 @@ def test_opwiseclinker_constant():
assert res == 15.3
class MyExc(Exception):
pass
def _my_checker(x, y):
if x[0] != y[0]:
raise MyExc("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]})
raise MyExc("Output mismatch.",
{'performlinker': x[0], 'clinker': y[0]})
###################
......@@ -233,22 +251,27 @@ def _my_checker(x, y):
def test_duallinker_straightforward():
x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(checker = _my_checker).accept(Env([x, y, z], [e]))
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(checker=_my_checker).accept(Env([x, y, z], [e]))
fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0)
assert res == 15.3
def test_duallinker_mismatch():
x, y, z = inputs()
e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python
# sub is correct in C but erroneous in Python
e = sub(mul(x, y), mul(y, z))
g = Env([x, y, z], [e])
lnk = DualLinker(checker = _my_checker).accept(g)
lnk = DualLinker(checker=_my_checker).accept(g)
fn = lnk.make_function()
assert CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 # good
assert OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 # good
assert PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0 # (purposely) wrong
# good
assert CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0
# good
assert OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0
# (purposely) wrong
assert PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0
try:
# this runs OpWiseCLinker and PerformLinker in parallel and feeds
......@@ -268,17 +291,19 @@ class AddFail(Binary):
def c_code(self, node, name, inp, out, sub):
x, y = inp
z, = out
fail=sub['fail']
fail = sub['fail']
return """%(z)s = %(x)s + %(y)s;
PyErr_SetString(PyExc_RuntimeError, "failing here");
%(fail)s;""" % locals()
def impl(self, x, y):
return x + y
add_fail = AddFail()
def test_fail_error():
x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x')
x = Constant(tdouble, 7.2, name='x')
e = add_fail(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function()
......@@ -286,7 +311,5 @@ def test_fail_error():
res = fn(1.5, 3.0)
except RuntimeError:
print 'Yay, TEST PASSED'
return #test passed
assert 0 #test failed
return # test passed
assert 0 # test failed
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论