提交 9ddd9748 authored 作者: Frederic's avatar Frederic

pep8

上级 8e512ce0
...@@ -45,6 +45,7 @@ class MyType(Type): ...@@ -45,6 +45,7 @@ class MyType(Type):
raise ValueError("Invalid value") raise ValueError("Invalid value")
return x return x
class MyOp(Op): class MyOp(Op):
def make_node(self, *inputs): def make_node(self, *inputs):
...@@ -81,14 +82,16 @@ class TestOp: ...@@ -81,14 +82,16 @@ class TestOp:
def test_sanity_0(self): def test_sanity_0(self):
r1, r2 = MyType(1)(), MyType(2)() r1, r2 = MyType(1)(), MyType(2)()
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
assert [x for x in node.inputs] == [r1, r2] # Are the inputs what I provided? # Are the inputs what I provided?
assert [x.type for x in node.outputs] == [MyType(3)] # Are the outputs what I expect? assert [x for x in node.inputs] == [r1, r2]
# Are the outputs what I expect?
assert [x.type for x in node.outputs] == [MyType(3)]
assert node.outputs[0].owner is node and node.outputs[0].index == 0 assert node.outputs[0].owner is node and node.outputs[0].index == 0
# validate # validate
def test_validate(self): def test_validate(self):
try: try:
MyOp(Generic()(), MyType(1)()) # MyOp requires MyType instances MyOp(Generic()(), MyType(1)()) # MyOp requires MyType instances
raise Exception("Expected an exception") raise Exception("Expected an exception")
except Exception, e: except Exception, e:
if str(e) != "Error 1": if str(e) != "Error 1":
...@@ -100,6 +103,7 @@ class TestOp: ...@@ -100,6 +103,7 @@ class TestOp:
rval = f() rval = f()
assert rval == 'test Op no input' assert rval == 'test Op no input'
class TestMakeThunk(unittest.TestCase): class TestMakeThunk(unittest.TestCase):
def test_no_c_code(self): def test_no_c_code(self):
class IncOnePython(Op): class IncOnePython(Op):
...@@ -121,28 +125,25 @@ class TestMakeThunk(unittest.TestCase): ...@@ -121,28 +125,25 @@ class TestMakeThunk(unittest.TestCase):
output, = outputs output, = outputs
output[0] = input + 1 output[0] = input + 1
i = scalar.int32('i') i = scalar.int32('i')
o = IncOnePython()(i) o = IncOnePython()(i)
# Check that the c_code function is not implemented # Check that the c_code function is not implemented
self.assertRaises((NotImplementedError, utils.MethodNotDefined), self.assertRaises((NotImplementedError, utils.MethodNotDefined),
o.owner.op.c_code, o.owner.op.c_code,
o.owner, 'o', ['x'], 'z', {'fail': ''}) o.owner, 'o', ['x'], 'z', {'fail': ''})
storage_map = { storage_map = {i: [numpy.int32(3)],
i: [numpy.int32(3)], o: [None]}
o: [None]} compute_map = {i: [True],
compute_map = { o: [False]}
i: [True],
o: [False]}
thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map, thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map,
no_recycling=[]) no_recycling=[])
required = thunk() required = thunk()
# Check everything went OK # Check everything went OK
assert not required # We provided all inputs assert not required # We provided all inputs
assert compute_map[o][0] assert compute_map[o][0]
assert storage_map[o][0] == 4 assert storage_map[o][0] == 4
...@@ -166,28 +167,25 @@ class TestMakeThunk(unittest.TestCase): ...@@ -166,28 +167,25 @@ class TestMakeThunk(unittest.TestCase):
z, = outputs z, = outputs
return "%(z)s = %(x)s + 1;" % locals() return "%(z)s = %(x)s + 1;" % locals()
i = scalar.int32('i') i = scalar.int32('i')
o = IncOneC()(i) o = IncOneC()(i)
# Check that the perform function is not implemented # Check that the perform function is not implemented
self.assertRaises((NotImplementedError, utils.MethodNotDefined), self.assertRaises((NotImplementedError, utils.MethodNotDefined),
o.owner.op.perform, o.owner.op.perform,
o.owner, 0, [None]) o.owner, 0, [None])
storage_map = { storage_map = {i: [numpy.int32(3)],
i: [numpy.int32(3)], o: [None]}
o: [None]} compute_map = {i: [True],
compute_map = { o: [False]}
i: [True],
o: [False]}
thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map, thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map,
no_recycling=[]) no_recycling=[])
if theano.config.cxx: if theano.config.cxx:
required = thunk() required = thunk()
# Check everything went OK # Check everything went OK
assert not required # We provided all inputs assert not required # We provided all inputs
assert compute_map[o][0] assert compute_map[o][0]
assert storage_map[o][0] == 4 assert storage_map[o][0] == 4
else: else:
...@@ -201,30 +199,33 @@ def test_test_value_python_objects(): ...@@ -201,30 +199,33 @@ def test_test_value_python_objects():
def test_test_value_ndarray(): def test_test_value_ndarray():
x = numpy.zeros((5,5)) x = numpy.zeros((5, 5))
v = op.get_test_value(x) v = op.get_test_value(x)
assert (v == x).all() assert (v == x).all()
def test_test_value_constant(): def test_test_value_constant():
x = T.as_tensor_variable(numpy.zeros((5,5))) x = T.as_tensor_variable(numpy.zeros((5, 5)))
v = op.get_test_value(x) v = op.get_test_value(x)
assert numpy.all(v == numpy.zeros((5,5))) assert numpy.all(v == numpy.zeros((5, 5)))
def test_test_value_shared(): def test_test_value_shared():
x = shared(numpy.zeros((5,5))) x = shared(numpy.zeros((5, 5)))
v = op.get_test_value(x) v = op.get_test_value(x)
assert numpy.all(v == numpy.zeros((5,5))) assert numpy.all(v == numpy.zeros((5, 5)))
def test_test_value_op(): def test_test_value_op():
try: try:
prev_value = config.compute_test_value prev_value = config.compute_test_value
config.compute_test_value = 'raise' config.compute_test_value = 'raise'
x = T.log(numpy.ones((5,5))) x = T.log(numpy.ones((5, 5)))
v = op.get_test_value(x) v = op.get_test_value(x)
assert numpy.allclose(v, numpy.zeros((5,5))) assert numpy.allclose(v, numpy.zeros((5, 5)))
finally: finally:
config.compute_test_value = prev_value config.compute_test_value = prev_value
...@@ -244,11 +245,11 @@ def test_get_debug_values_no_debugger(): ...@@ -244,11 +245,11 @@ def test_get_debug_values_no_debugger():
finally: finally:
config.compute_test_value = prev_value config.compute_test_value = prev_value
def test_get_det_debug_values_ignore(): def test_get_det_debug_values_ignore():
"""get_debug_values should return [] when debugger is ignore """get_debug_values should return [] when debugger is ignore
and some values are missing """ and some values are missing """
prev_value = config.compute_test_value prev_value = config.compute_test_value
try: try:
config.compute_test_value = 'ignore' config.compute_test_value = 'ignore'
...@@ -267,21 +268,21 @@ def test_get_debug_values_success(): ...@@ -267,21 +268,21 @@ def test_get_debug_values_success():
(and the debugger is on)""" (and the debugger is on)"""
prev_value = config.compute_test_value prev_value = config.compute_test_value
for mode in [ 'ignore', 'warn', 'raise' ]: for mode in ['ignore', 'warn', 'raise']:
try: try:
config.compute_test_value = mode config.compute_test_value = mode
x = T.vector() x = T.vector()
x.tag.test_value = numpy.zeros((4,), dtype=config.floatX) x.tag.test_value = numpy.zeros((4,), dtype=config.floatX)
y = numpy.zeros((5,5)) y = numpy.zeros((5, 5))
iters = 0 iters = 0
for x_val, y_val in op.get_debug_values(x, y): for x_val, y_val in op.get_debug_values(x, y):
assert x_val.shape == (4,) assert x_val.shape == (4,)
assert y_val.shape == (5,5) assert y_val.shape == (5, 5)
iters += 1 iters += 1
...@@ -290,6 +291,7 @@ def test_get_debug_values_success(): ...@@ -290,6 +291,7 @@ def test_get_debug_values_success():
finally: finally:
config.compute_test_value = prev_value config.compute_test_value = prev_value
def test_get_debug_values_exc(): def test_get_debug_values_exc():
"""tests that get_debug_value raises an exception when """tests that get_debug_value raises an exception when
debugger is set to raise and a value is missing """ debugger is set to raise and a value is missing """
...@@ -317,13 +319,14 @@ def test_get_debug_values_exc(): ...@@ -317,13 +319,14 @@ def test_get_debug_values_exc():
finally: finally:
config.compute_test_value = prev_value config.compute_test_value = prev_value
def test_debug_error_message(): def test_debug_error_message():
"""tests that debug_error_message raises an """tests that debug_error_message raises an
exception when it should.""" exception when it should."""
prev_value = config.compute_test_value prev_value = config.compute_test_value
for mode in [ 'ignore', 'raise' ]: for mode in ['ignore', 'raise']:
try: try:
config.compute_test_value = mode config.compute_test_value = mode
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论