提交 94f092ae authored 作者: notoraptor's avatar notoraptor

Implement `CEnumType` as suggested by @abergeron .

Update tests. `CEnumType` tested, and it works. I keep previous implemented enum classes (EnumType, and EnumList). Add C type checking.
上级 723addd1
...@@ -55,7 +55,7 @@ from theano.gof.link import \ ...@@ -55,7 +55,7 @@ from theano.gof.link import \
from theano.gof.op import \ from theano.gof.op import \
Op, OpenMPOp, PureOp, COp, ops_with_inner_function Op, OpenMPOp, PureOp, COp, ops_with_inner_function
from theano.gof.type import EnumType, EnumList from theano.gof.type import EnumType, EnumList, CEnumType
from theano.gof.opt import ( from theano.gof.opt import (
Optimizer, Optimizer,
......
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
import theano import theano
from theano import Op, Apply, scalar from theano import Op, Apply, scalar
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.gof.type import CDataType, EnumType, EnumList from theano.gof.type import CDataType, EnumType, EnumList, CEnumType
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
...@@ -80,7 +80,7 @@ def test_cdata(): ...@@ -80,7 +80,7 @@ def test_cdata():
class TestOpEnumList(Op): class TestOpEnumList(Op):
__props__ = ('op_chosen',) __props__ = ('op_chosen',)
params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE') params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE', ctype='unsigned long long')
def __init__(self, choose_op): def __init__(self, choose_op):
assert self.params_type.ADD == 0 assert self.params_type.ADD == 0
...@@ -109,12 +109,15 @@ class TestOpEnumList(Op): ...@@ -109,12 +109,15 @@ class TestOpEnumList(Op):
elif op == self.params_type.MULTIPLY: elif op == self.params_type.MULTIPLY:
o[0] = a * b o[0] = a * b
elif op == self.params_type.DIVIDE: elif op == self.params_type.DIVIDE:
o[0] = a / b if any(dtype in theano.tensor.continuous_dtypes for dtype in (a.dtype, b.dtype)):
o[0] = a / b
else:
o[0] = a // b
else: else:
raise NotImplementedError('Unknown op id ' + str(op)) raise NotImplementedError('Unknown op id ' + str(op))
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return None
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
a, b = inputs a, b = inputs
...@@ -142,7 +145,48 @@ class TestOpEnumList(Op): ...@@ -142,7 +145,48 @@ class TestOpEnumList(Op):
""" % locals() """ % locals()
def test_enum(): class TestOpCEnumType(Op):
__props__ = ('ctype_index',)
params_type = CEnumType('SIZE_INT', 'SIZE_FLOAT', 'SIZE_LONG_LONG', ctype='size_t')
# Just for testing, we define our own macros.
def c_support_code(self):
return """
#define SIZE_INT sizeof(int)
#define SIZE_FLOAT sizeof(float)
#define SIZE_LONG_LONG sizeof(long long)
"""
def __init__(self, ctype):
# As we see, Python values of constants are not related to real C values
# (sizeof(int) will never be 0).
assert self.params_type.SIZE_INT == 0
assert self.params_type.SIZE_FLOAT == 1
assert self.params_type.SIZE_LONG_LONG == 2
d = {'int': self.params_type.SIZE_INT,
'float': self.params_type.SIZE_FLOAT,
'long long': self.params_type.SIZE_LONG_LONG}
self.ctype_index = d[ctype]
def get_params(self, node):
return self.ctype_index
def make_node(self):
return Apply(self, [], [scalar.uint32()])
def c_code_cache_version(self):
return None
def c_code(self, node, name, inputs, outputs, sub):
o, = outputs
ctype_index = sub['params']
return """
/* ctype_index already contains expected C constant value. */
%(o)s = %(ctype_index)s;
""" % locals()
def test_enum_class():
# Check that invalid enum name raises exception. # Check that invalid enum name raises exception.
for invalid_name in ('a', '_A', '0'): for invalid_name in ('a', '_A', '0'):
try: try:
...@@ -174,6 +218,8 @@ def test_enum(): ...@@ -174,6 +218,8 @@ def test_enum():
assert not (e1 != e2) assert not (e1 != e2)
assert hash(e1) == hash(e2) assert hash(e1) == hash(e2)
def test_op_with_enumlist():
# Test an op with EnumList. # Test an op with EnumList.
a = scalar.int32() a = scalar.int32()
b = scalar.int32() b = scalar.int32()
...@@ -184,10 +230,15 @@ def test_enum(): ...@@ -184,10 +230,15 @@ def test_enum():
f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide]) f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide])
va = 12 va = 12
vb = 15 vb = 15
ref_add = va + vb ref = [va + vb, va - vb, va * vb, va // vb]
ref_sub = va - vb
ref_multiply = va * vb
ref_divide = va // vb
ref = [ref_add, ref_sub, ref_multiply, ref_divide]
out = f(va, vb) out = f(va, vb)
assert ref == out, (ref, out) assert ref == out, (ref, out)
def test_op_with_cenumtype():
sizeof_int = TestOpCEnumType('int')()
sizeof_float = TestOpCEnumType('float')()
sizeof_long_long = TestOpCEnumType('long long')()
f = theano.function([], [sizeof_int, sizeof_float, sizeof_long_long])
out = f()
print('(sizeof(int): ', out[0], ', sizeof(float): ', out[1], ', sizeof(long long): ', out[2], ') ', sep='', end='')
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论