提交 0f2d75be authored 作者: notoraptor's avatar notoraptor

Fix typos and move tests into a class.

上级 220e6ddb
...@@ -5,6 +5,7 @@ import theano ...@@ -5,6 +5,7 @@ import theano
from theano import Op, Apply, scalar from theano import Op, Apply, scalar
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.gof.type import CDataType, EnumType, EnumList, CEnumType from theano.gof.type import CDataType, EnumType, EnumList, CEnumType
from unittest import TestCase
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
...@@ -78,7 +79,7 @@ def test_cdata(): ...@@ -78,7 +79,7 @@ def test_cdata():
assert (v2 == v).all() assert (v2 == v).all()
class TestOpEnumList(Op): class MyOpEnumList(Op):
__props__ = ('op_chosen',) __props__ = ('op_chosen',)
params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE', ctype='unsigned long long') params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE', ctype='unsigned long long')
...@@ -120,10 +121,6 @@ class TestOpEnumList(Op): ...@@ -120,10 +121,6 @@ class TestOpEnumList(Op):
return (1,) return (1,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
a, b = inputs
o, = outputs
fail = sub['fail']
op = sub['params']
return """ return """
switch(%(op)s) { switch(%(op)s) {
case ADD: case ADD:
...@@ -142,10 +139,10 @@ class TestOpEnumList(Op): ...@@ -142,10 +139,10 @@ class TestOpEnumList(Op):
{%(fail)s} {%(fail)s}
break; break;
} }
""" % locals() """ % dict(op=sub['params'], o=outputs[0], a=inputs[0], b=inputs[1], fail=sub['fail'])
class TestOpCEnumType(Op): class MyOpCEnumType(Op):
__props__ = ('ctype_index',) __props__ = ('ctype_index',)
params_type = CEnumType('SIZE_INT', 'SIZE_FLOAT', 'SIZE_LONG_LONG', ctype='size_t') params_type = CEnumType('SIZE_INT', 'SIZE_FLOAT', 'SIZE_LONG_LONG', ctype='size_t')
...@@ -178,66 +175,67 @@ class TestOpCEnumType(Op): ...@@ -178,66 +175,67 @@ class TestOpCEnumType(Op):
return (1,) return (1,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
o, = outputs
# params in C code will already contains expected C constant value.
sizeof_ctype = sub['params']
return """ return """
%(o)s = %(sizeof_ctype)s; %(o)s = %(sizeof_ctype)s;
""" % locals() """ % dict(o=outputs[0],
# params in C code will already contains expected C constant value.
sizeof_ctype=sub['params'])
def test_enum_class(): class TestEnumTypes(TestCase):
# Check that invalid enum name raises exception.
for invalid_name in ('a', '_A', '0'): def test_enum_class(self):
try: # Check that invalid enum name raises exception.
EnumList(invalid_name) for invalid_name in ('a', '_A', '0'):
except AttributeError: try:
pass EnumList(invalid_name)
else: except AttributeError:
raise Exception('EnumList with invalid name should faild.') pass
else:
raise Exception('EnumList with invalid name should faild.')
try:
EnumType(**{invalid_name: 0})
except AttributeError:
pass
else:
raise Exception('EnumType with invalid name should fail.')
# Check that invalid enum value raises exception.
try: try:
EnumType(**{invalid_name: 0}) EnumType(INVALID_VALUE='string is not allowed.')
except AttributeError: except ValueError:
pass pass
else: else:
raise Exception('EnumType with invalid name should fail.') raise Exception('EnumType with invalid value should fail.')
# Check that invalid enum value raises exception. # Check EnumType.
try: e1 = EnumType(C1=True, C2=12, C3=True, C4=-1, C5=False, C6=0.0)
EnumType(INVALID_VALUE='string is not allowed.') e2 = EnumType(C1=1, C2=12, C3=1, C4=-1.0, C5=0.0, C6=0)
except ValueError: assert e1 == e2
pass assert not (e1 != e2)
else: assert hash(e1) == hash(e2)
raise Exception('EnumType with invalid value should fail.') # Check access to attributes.
assert len((e1.ctype, e1.C1, e1.C2, e1.C3, e1.C4, e1.C5, e1.C6)) == 7
# Check EnumType.
e1 = EnumType(C1=True, C2=12, C3=True, C4=-1, C5=False, C6=0.0) def test_op_with_enumlist(self):
e2 = EnumType(C1=1, C2=12, C3=1, C4=-1.0, C5=0.0, C6=0) a = scalar.int32()
assert e1 == e2 b = scalar.int32()
assert not (e1 != e2) c_add = MyOpEnumList('+')(a, b)
assert hash(e1) == hash(e2) c_sub = MyOpEnumList('-')(a, b)
c_multiply = MyOpEnumList('*')(a, b)
c_divide = MyOpEnumList('/')(a, b)
def test_op_with_enumlist(): f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide])
a = scalar.int32() va = 12
b = scalar.int32() vb = 15
c_add = TestOpEnumList('+')(a, b) ref = [va + vb, va - vb, va * vb, va // vb]
c_sub = TestOpEnumList('-')(a, b) out = f(va, vb)
c_multiply = TestOpEnumList('*')(a, b) assert ref == out, (ref, out)
c_divide = TestOpEnumList('/')(a, b)
f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide]) def test_op_with_cenumtype(self):
va = 12 sizeof_int = MyOpCEnumType('int')()
vb = 15 sizeof_float = MyOpCEnumType('float')()
ref = [va + vb, va - vb, va * vb, va // vb] sizeof_long_long = MyOpCEnumType('long long')()
out = f(va, vb) f = theano.function([], [sizeof_int, sizeof_float, sizeof_long_long])
assert ref == out, (ref, out) out = f()
print('(sizeof(int): ', out[0], ', sizeof(float): ', out[1], ', sizeof(long long): ', out[2], ') ', sep='')
def test_op_with_cenumtype():
sizeof_int = TestOpCEnumType('int')()
sizeof_float = TestOpCEnumType('float')()
sizeof_long_long = TestOpCEnumType('long long')()
f = theano.function([], [sizeof_int, sizeof_float, sizeof_long_long])
out = f()
print('(sizeof(int): ', out[0], ', sizeof(float): ', out[1], ', sizeof(long long): ', out[2], ') ', sep='')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论