提交 f493e77a authored 作者: Frederic's avatar Frederic

pep8

上级 c1366d70
...@@ -9,10 +9,12 @@ from theano.gof.op import Op ...@@ -9,10 +9,12 @@ from theano.gof.op import Op
from theano.gof import env from theano.gof import env
from theano.gof import toolbox from theano.gof import toolbox
def as_variable(x): def as_variable(x):
assert isinstance(x, Variable) assert isinstance(x, Variable)
return x return x
class TDouble(Type): class TDouble(Type):
def filter(self, data): def filter(self, data):
return float(data) return float(data)
...@@ -62,8 +64,9 @@ class TDouble(Type): ...@@ -62,8 +64,9 @@ class TDouble(Type):
tdouble = TDouble() tdouble = TDouble()
def double(name): def double(name):
return Variable(tdouble, None, None, name = name) return Variable(tdouble, None, None, name=name)
class MyOp(Op): class MyOp(Op):
...@@ -87,6 +90,7 @@ class MyOp(Op): ...@@ -87,6 +90,7 @@ class MyOp(Op):
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
out, = out_ out, = out_
out[0] = self.impl(*inputs) out[0] = self.impl(*inputs)
def c_code_cache_version(self): def c_code_cache_version(self):
return () return ()
...@@ -95,6 +99,7 @@ class Unary(MyOp): ...@@ -95,6 +99,7 @@ class Unary(MyOp):
def __init__(self): def __init__(self):
MyOp.__init__(self, 1, self.__class__.__name__) MyOp.__init__(self, 1, self.__class__.__name__)
class Binary(MyOp): class Binary(MyOp):
def __init__(self): def __init__(self):
MyOp.__init__(self, 2, self.__class__.__name__) MyOp.__init__(self, 2, self.__class__.__name__)
...@@ -105,37 +110,45 @@ class Add(Binary): ...@@ -105,37 +110,45 @@ class Add(Binary):
x, y = inp x, y = inp
z, = out z, = out
return "%(z)s = %(x)s + %(y)s;" % locals() return "%(z)s = %(x)s + %(y)s;" % locals()
def impl(self, x, y): def impl(self, x, y):
return x + y return x + y
add = Add() add = Add()
class Sub(Binary): class Sub(Binary):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, y = inp x, y = inp
z, = out z, = out
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def impl(self, x, y): def impl(self, x, y):
return -10 # erroneous (most of the time) return -10 # erroneous (most of the time)
sub = Sub() sub = Sub()
class Mul(Binary): class Mul(Binary):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, y = inp x, y = inp
z, = out z, = out
return "%(z)s = %(x)s * %(y)s;" % locals() return "%(z)s = %(x)s * %(y)s;" % locals()
def impl(self, x, y): def impl(self, x, y):
return x * y return x * y
mul = Mul() mul = Mul()
class Div(Binary): class Div(Binary):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, y = inp x, y = inp
z, = out z, = out
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def impl(self, x, y): def impl(self, x, y):
return x / y return x / y
div = Div() div = Div()
def inputs(): def inputs():
x = double('x') x = double('x')
y = double('y') y = double('y')
...@@ -159,6 +172,7 @@ def test_clinker_straightforward(): ...@@ -159,6 +172,7 @@ def test_clinker_straightforward():
fn = lnk.make_function() fn = lnk.make_function()
assert fn(2.0, 2.0, 2.0) == 2.0 assert fn(2.0, 2.0, 2.0) == 2.0
def test_clinker_literal_inlining(): def test_clinker_literal_inlining():
x, y, z = inputs() x, y, z = inputs()
z = Constant(tdouble, 4.12345678) z = Constant(tdouble, 4.12345678)
...@@ -169,7 +183,8 @@ def test_clinker_literal_inlining(): ...@@ -169,7 +183,8 @@ def test_clinker_literal_inlining():
code = lnk.code_gen() code = lnk.code_gen()
print "=== Code generated ===" print "=== Code generated ==="
print code print code
assert "4.12345678" in code # we expect the number to be inlined assert "4.12345678" in code # we expect the number to be inlined
def test_clinker_single_node(): def test_clinker_single_node():
x, y, z = inputs() x, y, z = inputs()
...@@ -178,6 +193,7 @@ def test_clinker_single_node(): ...@@ -178,6 +193,7 @@ def test_clinker_single_node():
fn = lnk.make_function() fn = lnk.make_function()
assert fn(2.0, 7.0) == 9 assert fn(2.0, 7.0) == 9
def test_clinker_dups(): def test_clinker_dups():
# Testing that duplicate inputs are allowed. # Testing that duplicate inputs are allowed.
x, y, z = inputs() x, y, z = inputs()
...@@ -187,6 +203,7 @@ def test_clinker_dups(): ...@@ -187,6 +203,7 @@ def test_clinker_dups():
assert fn(2.0, 2.0) == 4 assert fn(2.0, 2.0) == 4
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_dups_inner(): def test_clinker_dups_inner():
# Testing that duplicates are allowed inside the graph # Testing that duplicates are allowed inside the graph
x, y, z = inputs() x, y, z = inputs()
...@@ -196,7 +213,6 @@ def test_clinker_dups_inner(): ...@@ -196,7 +213,6 @@ def test_clinker_dups_inner():
assert fn(1.0, 2.0, 3.0) == 8.0 assert fn(1.0, 2.0, 3.0) == 8.0
###################### ######################
# Test OpWiseCLinker # # Test OpWiseCLinker #
###################### ######################
...@@ -208,9 +224,10 @@ def test_opwiseclinker_straightforward(): ...@@ -208,9 +224,10 @@ def test_opwiseclinker_straightforward():
fn = lnk.make_function() fn = lnk.make_function()
assert fn(2.0, 2.0, 2.0) == 2.0 assert fn(2.0, 2.0, 2.0) == 2.0
def test_opwiseclinker_constant(): def test_opwiseclinker_constant():
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x') x = Constant(tdouble, 7.2, name='x')
e = add(mul(x, y), mul(y, z)) e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e])) lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
...@@ -218,13 +235,14 @@ def test_opwiseclinker_constant(): ...@@ -218,13 +235,14 @@ def test_opwiseclinker_constant():
assert res == 15.3 assert res == 15.3
class MyExc(Exception): class MyExc(Exception):
pass pass
def _my_checker(x, y): def _my_checker(x, y):
if x[0] != y[0]: if x[0] != y[0]:
raise MyExc("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]}) raise MyExc("Output mismatch.",
{'performlinker': x[0], 'clinker': y[0]})
################### ###################
...@@ -233,22 +251,27 @@ def _my_checker(x, y): ...@@ -233,22 +251,27 @@ def _my_checker(x, y):
def test_duallinker_straightforward(): def test_duallinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(checker = _my_checker).accept(Env([x, y, z], [e])) lnk = DualLinker(checker=_my_checker).accept(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0) res = fn(7.2, 1.5, 3.0)
assert res == 15.3 assert res == 15.3
def test_duallinker_mismatch(): def test_duallinker_mismatch():
x, y, z = inputs() x, y, z = inputs()
e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python # sub is correct in C but erroneous in Python
e = sub(mul(x, y), mul(y, z))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
lnk = DualLinker(checker = _my_checker).accept(g) lnk = DualLinker(checker=_my_checker).accept(g)
fn = lnk.make_function() fn = lnk.make_function()
assert CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 # good # good
assert OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 # good assert CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0
assert PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0 # (purposely) wrong # good
assert OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0
# (purposely) wrong
assert PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0
try: try:
# this runs OpWiseCLinker and PerformLinker in parallel and feeds # this runs OpWiseCLinker and PerformLinker in parallel and feeds
...@@ -268,17 +291,19 @@ class AddFail(Binary): ...@@ -268,17 +291,19 @@ class AddFail(Binary):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, y = inp x, y = inp
z, = out z, = out
fail=sub['fail'] fail = sub['fail']
return """%(z)s = %(x)s + %(y)s; return """%(z)s = %(x)s + %(y)s;
PyErr_SetString(PyExc_RuntimeError, "failing here"); PyErr_SetString(PyExc_RuntimeError, "failing here");
%(fail)s;""" % locals() %(fail)s;""" % locals()
def impl(self, x, y): def impl(self, x, y):
return x + y return x + y
add_fail = AddFail() add_fail = AddFail()
def test_fail_error(): def test_fail_error():
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x') x = Constant(tdouble, 7.2, name='x')
e = add_fail(mul(x, y), mul(y, z)) e = add_fail(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e])) lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
...@@ -286,7 +311,5 @@ def test_fail_error(): ...@@ -286,7 +311,5 @@ def test_fail_error():
res = fn(1.5, 3.0) res = fn(1.5, 3.0)
except RuntimeError: except RuntimeError:
print 'Yay, TEST PASSED' print 'Yay, TEST PASSED'
return #test passed return # test passed
assert 0 #test failed assert 0 # test failed
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论