提交 0f2d75be authored 作者: notoraptor's avatar notoraptor

Fix typos and move tests into a class.

上级 220e6ddb
......@@ -5,6 +5,7 @@ import theano
from theano import Op, Apply, scalar
from theano.tensor import TensorType
from theano.gof.type import CDataType, EnumType, EnumList, CEnumType
from unittest import TestCase
from nose.plugins.skip import SkipTest
......@@ -78,7 +79,7 @@ def test_cdata():
assert (v2 == v).all()
class TestOpEnumList(Op):
class MyOpEnumList(Op):
__props__ = ('op_chosen',)
params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE', ctype='unsigned long long')
......@@ -120,10 +121,6 @@ class TestOpEnumList(Op):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
a, b = inputs
o, = outputs
fail = sub['fail']
op = sub['params']
return """
switch(%(op)s) {
case ADD:
......@@ -142,10 +139,10 @@ class TestOpEnumList(Op):
{%(fail)s}
break;
}
""" % locals()
""" % dict(op=sub['params'], o=outputs[0], a=inputs[0], b=inputs[1], fail=sub['fail'])
class TestOpCEnumType(Op):
class MyOpCEnumType(Op):
__props__ = ('ctype_index',)
params_type = CEnumType('SIZE_INT', 'SIZE_FLOAT', 'SIZE_LONG_LONG', ctype='size_t')
......@@ -178,66 +175,67 @@ class TestOpCEnumType(Op):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
o, = outputs
# params in C code will already contains expected C constant value.
sizeof_ctype = sub['params']
return """
%(o)s = %(sizeof_ctype)s;
""" % locals()
""" % dict(o=outputs[0],
# params in C code will already contains expected C constant value.
sizeof_ctype=sub['params'])
def test_enum_class():
# Check that invalid enum name raises exception.
for invalid_name in ('a', '_A', '0'):
try:
EnumList(invalid_name)
except AttributeError:
pass
else:
raise Exception('EnumList with invalid name should faild.')
class TestEnumTypes(TestCase):
def test_enum_class(self):
# Check that invalid enum name raises exception.
for invalid_name in ('a', '_A', '0'):
try:
EnumList(invalid_name)
except AttributeError:
pass
else:
raise Exception('EnumList with invalid name should faild.')
try:
EnumType(**{invalid_name: 0})
except AttributeError:
pass
else:
raise Exception('EnumType with invalid name should fail.')
# Check that invalid enum value raises exception.
try:
EnumType(**{invalid_name: 0})
except AttributeError:
EnumType(INVALID_VALUE='string is not allowed.')
except ValueError:
pass
else:
raise Exception('EnumType with invalid name should fail.')
# Check that invalid enum value raises exception.
try:
EnumType(INVALID_VALUE='string is not allowed.')
except ValueError:
pass
else:
raise Exception('EnumType with invalid value should fail.')
# Check EnumType.
e1 = EnumType(C1=True, C2=12, C3=True, C4=-1, C5=False, C6=0.0)
e2 = EnumType(C1=1, C2=12, C3=1, C4=-1.0, C5=0.0, C6=0)
assert e1 == e2
assert not (e1 != e2)
assert hash(e1) == hash(e2)
def test_op_with_enumlist():
a = scalar.int32()
b = scalar.int32()
c_add = TestOpEnumList('+')(a, b)
c_sub = TestOpEnumList('-')(a, b)
c_multiply = TestOpEnumList('*')(a, b)
c_divide = TestOpEnumList('/')(a, b)
f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide])
va = 12
vb = 15
ref = [va + vb, va - vb, va * vb, va // vb]
out = f(va, vb)
assert ref == out, (ref, out)
def test_op_with_cenumtype():
sizeof_int = TestOpCEnumType('int')()
sizeof_float = TestOpCEnumType('float')()
sizeof_long_long = TestOpCEnumType('long long')()
f = theano.function([], [sizeof_int, sizeof_float, sizeof_long_long])
out = f()
print('(sizeof(int): ', out[0], ', sizeof(float): ', out[1], ', sizeof(long long): ', out[2], ') ', sep='')
raise Exception('EnumType with invalid value should fail.')
# Check EnumType.
e1 = EnumType(C1=True, C2=12, C3=True, C4=-1, C5=False, C6=0.0)
e2 = EnumType(C1=1, C2=12, C3=1, C4=-1.0, C5=0.0, C6=0)
assert e1 == e2
assert not (e1 != e2)
assert hash(e1) == hash(e2)
# Check access to attributes.
assert len((e1.ctype, e1.C1, e1.C2, e1.C3, e1.C4, e1.C5, e1.C6)) == 7
def test_op_with_enumlist(self):
a = scalar.int32()
b = scalar.int32()
c_add = MyOpEnumList('+')(a, b)
c_sub = MyOpEnumList('-')(a, b)
c_multiply = MyOpEnumList('*')(a, b)
c_divide = MyOpEnumList('/')(a, b)
f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide])
va = 12
vb = 15
ref = [va + vb, va - vb, va * vb, va // vb]
out = f(va, vb)
assert ref == out, (ref, out)
def test_op_with_cenumtype(self):
sizeof_int = MyOpCEnumType('int')()
sizeof_float = MyOpCEnumType('float')()
sizeof_long_long = MyOpCEnumType('long long')()
f = theano.function([], [sizeof_int, sizeof_float, sizeof_long_long])
out = f()
print('(sizeof(int): ', out[0], ', sizeof(float): ', out[1], ', sizeof(long long): ', out[2], ') ', sep='')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论