提交 9ddd9748 authored 作者: Frederic's avatar Frederic

pep8

上级 8e512ce0
......@@ -45,6 +45,7 @@ class MyType(Type):
raise ValueError("Invalid value")
return x
class MyOp(Op):
def make_node(self, *inputs):
......@@ -81,14 +82,16 @@ class TestOp:
def test_sanity_0(self):
r1, r2 = MyType(1)(), MyType(2)()
node = MyOp.make_node(r1, r2)
assert [x for x in node.inputs] == [r1, r2] # Are the inputs what I provided?
assert [x.type for x in node.outputs] == [MyType(3)] # Are the outputs what I expect?
# Are the inputs what I provided?
assert [x for x in node.inputs] == [r1, r2]
# Are the outputs what I expect?
assert [x.type for x in node.outputs] == [MyType(3)]
assert node.outputs[0].owner is node and node.outputs[0].index == 0
# validate
def test_validate(self):
try:
MyOp(Generic()(), MyType(1)()) # MyOp requires MyType instances
MyOp(Generic()(), MyType(1)()) # MyOp requires MyType instances
raise Exception("Expected an exception")
except Exception, e:
if str(e) != "Error 1":
......@@ -100,6 +103,7 @@ class TestOp:
rval = f()
assert rval == 'test Op no input'
class TestMakeThunk(unittest.TestCase):
def test_no_c_code(self):
class IncOnePython(Op):
......@@ -121,28 +125,25 @@ class TestMakeThunk(unittest.TestCase):
output, = outputs
output[0] = input + 1
i = scalar.int32('i')
o = IncOnePython()(i)
# Check that the c_code function is not implemented
self.assertRaises((NotImplementedError, utils.MethodNotDefined),
o.owner.op.c_code,
o.owner, 'o', ['x'], 'z', {'fail': ''})
o.owner.op.c_code,
o.owner, 'o', ['x'], 'z', {'fail': ''})
storage_map = {
i: [numpy.int32(3)],
o: [None]}
compute_map = {
i: [True],
o: [False]}
storage_map = {i: [numpy.int32(3)],
o: [None]}
compute_map = {i: [True],
o: [False]}
thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map,
no_recycling=[])
no_recycling=[])
required = thunk()
# Check everything went OK
assert not required # We provided all inputs
assert not required # We provided all inputs
assert compute_map[o][0]
assert storage_map[o][0] == 4
......@@ -166,28 +167,25 @@ class TestMakeThunk(unittest.TestCase):
z, = outputs
return "%(z)s = %(x)s + 1;" % locals()
i = scalar.int32('i')
o = IncOneC()(i)
# Check that the perform function is not implemented
self.assertRaises((NotImplementedError, utils.MethodNotDefined),
o.owner.op.perform,
o.owner, 0, [None])
o.owner.op.perform,
o.owner, 0, [None])
storage_map = {
i: [numpy.int32(3)],
o: [None]}
compute_map = {
i: [True],
o: [False]}
storage_map = {i: [numpy.int32(3)],
o: [None]}
compute_map = {i: [True],
o: [False]}
thunk = o.owner.op.make_thunk(o.owner, storage_map, compute_map,
no_recycling=[])
no_recycling=[])
if theano.config.cxx:
required = thunk()
# Check everything went OK
assert not required # We provided all inputs
assert not required # We provided all inputs
assert compute_map[o][0]
assert storage_map[o][0] == 4
else:
......@@ -201,30 +199,33 @@ def test_test_value_python_objects():
def test_test_value_ndarray():
x = numpy.zeros((5,5))
x = numpy.zeros((5, 5))
v = op.get_test_value(x)
assert (v == x).all()
def test_test_value_constant():
x = T.as_tensor_variable(numpy.zeros((5,5)))
x = T.as_tensor_variable(numpy.zeros((5, 5)))
v = op.get_test_value(x)
assert numpy.all(v == numpy.zeros((5,5)))
assert numpy.all(v == numpy.zeros((5, 5)))
def test_test_value_shared():
x = shared(numpy.zeros((5,5)))
x = shared(numpy.zeros((5, 5)))
v = op.get_test_value(x)
assert numpy.all(v == numpy.zeros((5,5)))
assert numpy.all(v == numpy.zeros((5, 5)))
def test_test_value_op():
try:
prev_value = config.compute_test_value
config.compute_test_value = 'raise'
x = T.log(numpy.ones((5,5)))
x = T.log(numpy.ones((5, 5)))
v = op.get_test_value(x)
assert numpy.allclose(v, numpy.zeros((5,5)))
assert numpy.allclose(v, numpy.zeros((5, 5)))
finally:
config.compute_test_value = prev_value
......@@ -244,11 +245,11 @@ def test_get_debug_values_no_debugger():
finally:
config.compute_test_value = prev_value
def test_get_det_debug_values_ignore():
"""get_debug_values should return [] when debugger is ignore
and some values are missing """
prev_value = config.compute_test_value
try:
config.compute_test_value = 'ignore'
......@@ -267,21 +268,21 @@ def test_get_debug_values_success():
(and the debugger is on)"""
prev_value = config.compute_test_value
for mode in [ 'ignore', 'warn', 'raise' ]:
for mode in ['ignore', 'warn', 'raise']:
try:
config.compute_test_value = mode
x = T.vector()
x.tag.test_value = numpy.zeros((4,), dtype=config.floatX)
y = numpy.zeros((5,5))
y = numpy.zeros((5, 5))
iters = 0
for x_val, y_val in op.get_debug_values(x, y):
assert x_val.shape == (4,)
assert y_val.shape == (5,5)
assert y_val.shape == (5, 5)
iters += 1
......@@ -290,6 +291,7 @@ def test_get_debug_values_success():
finally:
config.compute_test_value = prev_value
def test_get_debug_values_exc():
"""tests that get_debug_value raises an exception when
debugger is set to raise and a value is missing """
......@@ -317,13 +319,14 @@ def test_get_debug_values_exc():
finally:
config.compute_test_value = prev_value
def test_debug_error_message():
"""tests that debug_error_message raises an
exception when it should."""
prev_value = config.compute_test_value
for mode in [ 'ignore', 'raise' ]:
for mode in ['ignore', 'raise']:
try:
config.compute_test_value = mode
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论