提交 94f092ae authored 作者: notoraptor's avatar notoraptor

Implement `CEnumType` as suggested by @abergeron .

Update tests. `CEnumType` tested, and it works. I keep previous implemented enum classes (EnumType, and EnumList). Add C type checking.
上级 723addd1
......@@ -55,7 +55,7 @@ from theano.gof.link import \
from theano.gof.op import \
Op, OpenMPOp, PureOp, COp, ops_with_inner_function
from theano.gof.type import EnumType, EnumList
from theano.gof.type import EnumType, EnumList, CEnumType
from theano.gof.opt import (
Optimizer,
......
......@@ -4,7 +4,7 @@ import numpy as np
import theano
from theano import Op, Apply, scalar
from theano.tensor import TensorType
from theano.gof.type import CDataType, EnumType, EnumList
from theano.gof.type import CDataType, EnumType, EnumList, CEnumType
from nose.plugins.skip import SkipTest
......@@ -80,7 +80,7 @@ def test_cdata():
class TestOpEnumList(Op):
__props__ = ('op_chosen',)
params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE')
params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE', ctype='unsigned long long')
def __init__(self, choose_op):
assert self.params_type.ADD == 0
......@@ -109,12 +109,15 @@ class TestOpEnumList(Op):
elif op == self.params_type.MULTIPLY:
o[0] = a * b
elif op == self.params_type.DIVIDE:
o[0] = a / b
if any(dtype in theano.tensor.continuous_dtypes for dtype in (a.dtype, b.dtype)):
o[0] = a / b
else:
o[0] = a // b
else:
raise NotImplementedError('Unknown op id ' + str(op))
def c_code_cache_version(self):
return (1,)
return None
def c_code(self, node, name, inputs, outputs, sub):
a, b = inputs
......@@ -142,7 +145,48 @@ class TestOpEnumList(Op):
""" % locals()
def test_enum():
class TestOpCEnumType(Op):
__props__ = ('ctype_index',)
params_type = CEnumType('SIZE_INT', 'SIZE_FLOAT', 'SIZE_LONG_LONG', ctype='size_t')
# Just for testing, we define our own macros.
def c_support_code(self):
return """
#define SIZE_INT sizeof(int)
#define SIZE_FLOAT sizeof(float)
#define SIZE_LONG_LONG sizeof(long long)
"""
def __init__(self, ctype):
# As we see, Python values of constants are not related to real C values
# (sizeof(int) will never be 0).
assert self.params_type.SIZE_INT == 0
assert self.params_type.SIZE_FLOAT == 1
assert self.params_type.SIZE_LONG_LONG == 2
d = {'int': self.params_type.SIZE_INT,
'float': self.params_type.SIZE_FLOAT,
'long long': self.params_type.SIZE_LONG_LONG}
self.ctype_index = d[ctype]
def get_params(self, node):
return self.ctype_index
def make_node(self):
return Apply(self, [], [scalar.uint32()])
def c_code_cache_version(self):
return None
def c_code(self, node, name, inputs, outputs, sub):
o, = outputs
ctype_index = sub['params']
return """
/* ctype_index already contains expected C constant value. */
%(o)s = %(ctype_index)s;
""" % locals()
def test_enum_class():
# Check that invalid enum name raises exception.
for invalid_name in ('a', '_A', '0'):
try:
......@@ -174,6 +218,8 @@ def test_enum():
assert not (e1 != e2)
assert hash(e1) == hash(e2)
def test_op_with_enumlist():
# Test an op with EnumList.
a = scalar.int32()
b = scalar.int32()
......@@ -184,10 +230,15 @@ def test_enum():
f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide])
va = 12
vb = 15
ref_add = va + vb
ref_sub = va - vb
ref_multiply = va * vb
ref_divide = va // vb
ref = [ref_add, ref_sub, ref_multiply, ref_divide]
ref = [va + vb, va - vb, va * vb, va // vb]
out = f(va, vb)
assert ref == out, (ref, out)
def test_op_with_cenumtype():
sizeof_int = TestOpCEnumType('int')()
sizeof_float = TestOpCEnumType('float')()
sizeof_long_long = TestOpCEnumType('long long')()
f = theano.function([], [sizeof_int, sizeof_float, sizeof_long_long])
out = f()
print('(sizeof(int): ', out[0], ', sizeof(float): ', out[1], ', sizeof(long long): ', out[2], ') ', sep='', end='')
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论