提交 9ddd9748 authored 作者: Frederic's avatar Frederic

pep8

上级 8e512ce0
...@@ -45,6 +45,7 @@ class MyType(Type): ...@@ -45,6 +45,7 @@ class MyType(Type):
raise ValueError("Invalid value") raise ValueError("Invalid value")
return x return x
class MyOp(Op): class MyOp(Op):
def make_node(self, *inputs): def make_node(self, *inputs):
...@@ -81,8 +82,10 @@ class TestOp: ...@@ -81,8 +82,10 @@ class TestOp:
def test_sanity_0(self): def test_sanity_0(self):
r1, r2 = MyType(1)(), MyType(2)() r1, r2 = MyType(1)(), MyType(2)()
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
assert [x for x in node.inputs] == [r1, r2] # Are the inputs what I provided? # Are the inputs what I provided?
assert [x.type for x in node.outputs] == [MyType(3)] # Are the outputs what I expect? assert [x for x in node.inputs] == [r1, r2]
# Are the outputs what I expect?
assert [x.type for x in node.outputs] == [MyType(3)]
assert node.outputs[0].owner is node and node.outputs[0].index == 0 assert node.outputs[0].owner is node and node.outputs[0].index == 0
# validate # validate
...@@ -100,6 +103,7 @@ class TestOp: ...@@ -100,6 +103,7 @@ class TestOp:
rval = f() rval = f()
assert rval == 'test Op no input' assert rval == 'test Op no input'
class TestMakeThunk(unittest.TestCase): class TestMakeThunk(unittest.TestCase):
def test_no_c_code(self): def test_no_c_code(self):
class IncOnePython(Op): class IncOnePython(Op):
...@@ -121,7 +125,6 @@ class TestMakeThunk(unittest.TestCase): ...@@ -121,7 +125,6 @@ class TestMakeThunk(unittest.TestCase):
output, = outputs output, = outputs
output[0] = input + 1 output[0] = input + 1
i = scalar.int32('i') i = scalar.int32('i')
o = IncOnePython()(i) o = IncOnePython()(i)
...@@ -130,11 +133,9 @@ class TestMakeThunk(unittest.TestCase): ...@@ -130,11 +133,9 @@ class TestMakeThunk(unittest.TestCase):
o.owner.op.c_code, o.owner.op.c_code,
o.owner, 'o', ['x'], 'z', {'fail': ''}) o.owner, 'o', ['x'], 'z', {'fail': ''})
storage_map = { storage_map = {i: [numpy.int32(3)],
i: [numpy.int32(3)],
o: [None]} o: [None]}
compute_map = { compute_map = {i: [True],
i: [True],
o: [False]} o: [False]}
thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map, thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map,
...@@ -166,7 +167,6 @@ class TestMakeThunk(unittest.TestCase): ...@@ -166,7 +167,6 @@ class TestMakeThunk(unittest.TestCase):
z, = outputs z, = outputs
return "%(z)s = %(x)s + 1;" % locals() return "%(z)s = %(x)s + 1;" % locals()
i = scalar.int32('i') i = scalar.int32('i')
o = IncOneC()(i) o = IncOneC()(i)
...@@ -175,11 +175,9 @@ class TestMakeThunk(unittest.TestCase): ...@@ -175,11 +175,9 @@ class TestMakeThunk(unittest.TestCase):
o.owner.op.perform, o.owner.op.perform,
o.owner, 0, [None]) o.owner, 0, [None])
storage_map = { storage_map = {i: [numpy.int32(3)],
i: [numpy.int32(3)],
o: [None]} o: [None]}
compute_map = { compute_map = {i: [True],
i: [True],
o: [False]} o: [False]}
thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map, thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map,
...@@ -201,30 +199,33 @@ def test_test_value_python_objects(): ...@@ -201,30 +199,33 @@ def test_test_value_python_objects():
def test_test_value_ndarray(): def test_test_value_ndarray():
x = numpy.zeros((5,5)) x = numpy.zeros((5, 5))
v = op.get_test_value(x) v = op.get_test_value(x)
assert (v == x).all() assert (v == x).all()
def test_test_value_constant(): def test_test_value_constant():
x = T.as_tensor_variable(numpy.zeros((5,5))) x = T.as_tensor_variable(numpy.zeros((5, 5)))
v = op.get_test_value(x) v = op.get_test_value(x)
assert numpy.all(v == numpy.zeros((5,5))) assert numpy.all(v == numpy.zeros((5, 5)))
def test_test_value_shared(): def test_test_value_shared():
x = shared(numpy.zeros((5,5))) x = shared(numpy.zeros((5, 5)))
v = op.get_test_value(x) v = op.get_test_value(x)
assert numpy.all(v == numpy.zeros((5,5))) assert numpy.all(v == numpy.zeros((5, 5)))
def test_test_value_op(): def test_test_value_op():
try: try:
prev_value = config.compute_test_value prev_value = config.compute_test_value
config.compute_test_value = 'raise' config.compute_test_value = 'raise'
x = T.log(numpy.ones((5,5))) x = T.log(numpy.ones((5, 5)))
v = op.get_test_value(x) v = op.get_test_value(x)
assert numpy.allclose(v, numpy.zeros((5,5))) assert numpy.allclose(v, numpy.zeros((5, 5)))
finally: finally:
config.compute_test_value = prev_value config.compute_test_value = prev_value
...@@ -244,11 +245,11 @@ def test_get_debug_values_no_debugger(): ...@@ -244,11 +245,11 @@ def test_get_debug_values_no_debugger():
finally: finally:
config.compute_test_value = prev_value config.compute_test_value = prev_value
def test_get_det_debug_values_ignore(): def test_get_det_debug_values_ignore():
"""get_debug_values should return [] when debugger is ignore """get_debug_values should return [] when debugger is ignore
and some values are missing """ and some values are missing """
prev_value = config.compute_test_value prev_value = config.compute_test_value
try: try:
config.compute_test_value = 'ignore' config.compute_test_value = 'ignore'
...@@ -267,21 +268,21 @@ def test_get_debug_values_success(): ...@@ -267,21 +268,21 @@ def test_get_debug_values_success():
(and the debugger is on)""" (and the debugger is on)"""
prev_value = config.compute_test_value prev_value = config.compute_test_value
for mode in [ 'ignore', 'warn', 'raise' ]: for mode in ['ignore', 'warn', 'raise']:
try: try:
config.compute_test_value = mode config.compute_test_value = mode
x = T.vector() x = T.vector()
x.tag.test_value = numpy.zeros((4,), dtype=config.floatX) x.tag.test_value = numpy.zeros((4,), dtype=config.floatX)
y = numpy.zeros((5,5)) y = numpy.zeros((5, 5))
iters = 0 iters = 0
for x_val, y_val in op.get_debug_values(x, y): for x_val, y_val in op.get_debug_values(x, y):
assert x_val.shape == (4,) assert x_val.shape == (4,)
assert y_val.shape == (5,5) assert y_val.shape == (5, 5)
iters += 1 iters += 1
...@@ -290,6 +291,7 @@ def test_get_debug_values_success(): ...@@ -290,6 +291,7 @@ def test_get_debug_values_success():
finally: finally:
config.compute_test_value = prev_value config.compute_test_value = prev_value
def test_get_debug_values_exc(): def test_get_debug_values_exc():
"""tests that get_debug_value raises an exception when """tests that get_debug_value raises an exception when
debugger is set to raise and a value is missing """ debugger is set to raise and a value is missing """
...@@ -317,13 +319,14 @@ def test_get_debug_values_exc(): ...@@ -317,13 +319,14 @@ def test_get_debug_values_exc():
finally: finally:
config.compute_test_value = prev_value config.compute_test_value = prev_value
def test_debug_error_message(): def test_debug_error_message():
"""tests that debug_error_message raises an """tests that debug_error_message raises an
exception when it should.""" exception when it should."""
prev_value = config.compute_test_value prev_value = config.compute_test_value
for mode in [ 'ignore', 'raise' ]: for mode in ['ignore', 'raise']:
try: try:
config.compute_test_value = mode config.compute_test_value = mode
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论