提交 94f092ae authored 作者: notoraptor's avatar notoraptor

Implement `CEnumType` as suggested by @abergeron .

Update tests. `CEnumType` is tested and works. The previously implemented enum classes (`EnumType` and `EnumList`) are kept. Add C type checking.
上级 723addd1
...@@ -55,7 +55,7 @@ from theano.gof.link import \ ...@@ -55,7 +55,7 @@ from theano.gof.link import \
from theano.gof.op import \ from theano.gof.op import \
Op, OpenMPOp, PureOp, COp, ops_with_inner_function Op, OpenMPOp, PureOp, COp, ops_with_inner_function
from theano.gof.type import EnumType, EnumList from theano.gof.type import EnumType, EnumList, CEnumType
from theano.gof.opt import ( from theano.gof.opt import (
Optimizer, Optimizer,
......
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
import theano import theano
from theano import Op, Apply, scalar from theano import Op, Apply, scalar
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.gof.type import CDataType, EnumType, EnumList from theano.gof.type import CDataType, EnumType, EnumList, CEnumType
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
...@@ -80,7 +80,7 @@ def test_cdata(): ...@@ -80,7 +80,7 @@ def test_cdata():
class TestOpEnumList(Op): class TestOpEnumList(Op):
__props__ = ('op_chosen',) __props__ = ('op_chosen',)
params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE') params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE', ctype='unsigned long long')
def __init__(self, choose_op): def __init__(self, choose_op):
assert self.params_type.ADD == 0 assert self.params_type.ADD == 0
...@@ -109,12 +109,15 @@ class TestOpEnumList(Op): ...@@ -109,12 +109,15 @@ class TestOpEnumList(Op):
elif op == self.params_type.MULTIPLY: elif op == self.params_type.MULTIPLY:
o[0] = a * b o[0] = a * b
elif op == self.params_type.DIVIDE: elif op == self.params_type.DIVIDE:
o[0] = a / b if any(dtype in theano.tensor.continuous_dtypes for dtype in (a.dtype, b.dtype)):
o[0] = a / b
else:
o[0] = a // b
else: else:
raise NotImplementedError('Unknown op id ' + str(op)) raise NotImplementedError('Unknown op id ' + str(op))
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return None
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
a, b = inputs a, b = inputs
...@@ -142,7 +145,48 @@ class TestOpEnumList(Op): ...@@ -142,7 +145,48 @@ class TestOpEnumList(Op):
""" % locals() """ % locals()
class TestOpCEnumType(Op):
    """
    Test op with no input whose single output is the C value of the
    ``CEnumType`` constant selected at construction time.
    """
    __props__ = ('ctype_index',)
    params_type = CEnumType('SIZE_INT', 'SIZE_FLOAT', 'SIZE_LONG_LONG', ctype='size_t')

    # Just for testing, we define our own macros.
    def c_support_code(self):
        return """
        #define SIZE_INT sizeof(int)
        #define SIZE_FLOAT sizeof(float)
        #define SIZE_LONG_LONG sizeof(long long)
        """

    def __init__(self, ctype):
        enum = self.params_type
        # Python-side values are plain ranks; they are unrelated to the
        # real C values (sizeof(int) will never be 0).
        assert enum.SIZE_INT == 0
        assert enum.SIZE_FLOAT == 1
        assert enum.SIZE_LONG_LONG == 2
        index_by_ctype = {
            'int': enum.SIZE_INT,
            'float': enum.SIZE_FLOAT,
            'long long': enum.SIZE_LONG_LONG,
        }
        self.ctype_index = index_by_ctype[ctype]

    def get_params(self, node):
        # The op parameter is the Python-side rank of the chosen constant.
        return self.ctype_index

    def make_node(self):
        # No input, one uint32 scalar output.
        return Apply(self, [], [scalar.uint32()])

    def c_code_cache_version(self):
        return None

    def c_code(self, node, name, inputs, outputs, sub):
        out_name = outputs[0]
        params_name = sub['params']
        return """
        /* ctype_index already contains expected C constant value. */
        %(o)s = %(ctype_index)s;
        """ % dict(o=out_name, ctype_index=params_name)
def test_enum_class():
# Check that invalid enum name raises exception. # Check that invalid enum name raises exception.
for invalid_name in ('a', '_A', '0'): for invalid_name in ('a', '_A', '0'):
try: try:
...@@ -174,6 +218,8 @@ def test_enum(): ...@@ -174,6 +218,8 @@ def test_enum():
assert not (e1 != e2) assert not (e1 != e2)
assert hash(e1) == hash(e2) assert hash(e1) == hash(e2)
def test_op_with_enumlist():
# Test an op with EnumList. # Test an op with EnumList.
a = scalar.int32() a = scalar.int32()
b = scalar.int32() b = scalar.int32()
...@@ -184,10 +230,15 @@ def test_enum(): ...@@ -184,10 +230,15 @@ def test_enum():
f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide]) f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide])
va = 12 va = 12
vb = 15 vb = 15
ref_add = va + vb ref = [va + vb, va - vb, va * vb, va // vb]
ref_sub = va - vb
ref_multiply = va * vb
ref_divide = va // vb
ref = [ref_add, ref_sub, ref_multiply, ref_divide]
out = f(va, vb) out = f(va, vb)
assert ref == out, (ref, out) assert ref == out, (ref, out)
def test_op_with_cenumtype():
    """
    Compile a function made of three ``TestOpCEnumType`` ops and check
    that the outputs are the C-side values of the constants.
    """
    sizeof_int = TestOpCEnumType('int')()
    sizeof_float = TestOpCEnumType('float')()
    sizeof_long_long = TestOpCEnumType('long long')()
    f = theano.function([], [sizeof_int, sizeof_float, sizeof_long_long])
    out = f()
    print('(sizeof(int): ', out[0], ', sizeof(float): ', out[1], ', sizeof(long long): ', out[2], ') ', sep='', end='')
    # The Python-side values of the constants are 0, 1 and 2, whereas a C
    # sizeof() is never 0. A zero here would mean the Python rank leaked
    # through instead of the C constant.
    assert all(value > 0 for value in out), out
...@@ -814,15 +814,18 @@ CDataType.Constant = CDataTypeConstant ...@@ -814,15 +814,18 @@ CDataType.Constant = CDataTypeConstant
class EnumType(Type, dict): class EnumType(Type, dict):
""" """
Class that allows to create enumerations of constant values. Class that allows to create enumerations of constant values.
Constants are available as object attributes in Python code and as macro-defined constants in C code.
Constants can be floating values, integers, or booleans (automatically converted to integers). - Constants are available as object attributes in Python code and as macro-defined constants in C code.
Constants name must start with a capital letter and contain capital letters, underscores or digits. - Constants can be floating values, integers, or booleans (automatically converted to integers).
- Constants name must start with a capital letter and contain capital letters, underscores or digits.
This type is intended to be used as op parameter type. This type is intended to be used as op parameter type.
Example:: Example::
enum = EnumType(CONSTANT_1=0, CONSTANT_2=1, CONSTANT_3=2.5, CONSTANT_4=False, CONSTANT_5=True) enum = EnumType(CONSTANT_1=0, CONSTANT_2=1, CONSTANT_3=2.5, CONSTANT_4=False, CONSTANT_5=True)
print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4, enum.CONSTANT_5) print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4, enum.CONSTANT_5)
# will print 0 1 2.5 0 1
In C code: In C code:
...@@ -834,51 +837,74 @@ class EnumType(Type, dict): ...@@ -834,51 +837,74 @@ class EnumType(Type, dict):
int constant_4 = CONSTANT_4; // constant_4 == 0 int constant_4 = CONSTANT_4; // constant_4 == 0
int constant_5 = CONSTANT_5; // constant_5 == 1 int constant_5 = CONSTANT_5; // constant_5 == 1
You can also specify a C type if you want to use an op param to handle these enum values.
Default C type is ``double``.
.. code-block:: python
enum = EnumType(CONSTANT_1=0, CONSTANT_2=1, CONSTANT_3=2, ctype='size_t')
..note:: In C code:
.. code-block:: c
size_t op_param = CONSTANT_1;
.. note::
This Type is not complete and should never be used for regular graph operations. This Type is not complete and should never be used for regular graph operations.
""" """
def check_ctype(self):
    """
    Validate ``self.ctype`` as a C type name.

    A C type may be a list of keywords (e.g. ``unsigned long long``), so
    every whitespace-separated part is checked to be a valid C identifier.

    Raises
    ------
    TypeError
        If ``self.ctype`` is empty or blank, or if any of its parts is not
        a valid C identifier.
    """
    parts = self.ctype.split()
    # ``all()`` of an empty sequence is True, so an empty or blank C type
    # must be rejected explicitly.
    if not parts or not all(re.match('^[A-Za-z_][A-Za-z0-9_]*$', el) for el in parts):
        raise TypeError('%s: invalid C type' % type(self).__name__)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.ctype = kwargs.pop('ctype', 'double')
self.check_ctype()
for k in kwargs: for k in kwargs:
if re.match('^[A-Z][A-Z0-9_]*$', k) is None: if re.match('^[A-Z][A-Z0-9_]*$', k) is None:
raise AttributeError('EnumType: invalid enum name: "%s". ' raise AttributeError('%s: invalid enum name: "%s". '
'Only capital letters, underscores and digits ' 'Only capital letters, underscores and digits '
'are allowed.' % k) 'are allowed.' % (type(self).__name__, k))
if isinstance(kwargs[k], bool): if isinstance(kwargs[k], bool):
kwargs[k] = int(kwargs[k]) kwargs[k] = int(kwargs[k])
elif not isinstance(kwargs[k], (int, float)): elif not isinstance(kwargs[k], (int, float)):
raise ValueError('EnumType: enum "%s": expected integer or floating value, got "%s".' raise ValueError('%s: constant "%s": expected integer or floating value, got "%s".'
% (k, type(kwargs[k]).__name__)) % (type(self).__name__, k, type(kwargs[k]).__name__))
super(EnumType, self).__init__(**kwargs) super(EnumType, self).__init__(**kwargs)
def __repr__(self): def __repr__(self):
return 'EnumType(%s)' % ', '.join('%s:%s' % (k, self[k]) for k in sorted(self.keys())) return '%s(%s)' % (type(self).__name__, ', '.join('%s:%s' % (k, self[k]) for k in sorted(self.keys())))
def __getattr__(self, key):
    """
    Expose enum constants as attributes.

    Python calls this hook only after normal attribute lookup has failed,
    so it is reached for constant names (stored as dict items) and for
    attributes that are genuinely missing.
    """
    if key in self:
        return self[key]
    # No special case for 'ctype': when it is set, normal lookup finds it
    # and this hook is never reached. When it is not set, reading
    # ``self.ctype`` from here would re-enter this hook and recurse
    # without bound.
    return Type.__getattr__(self, key)
def __setattr__(self, key, value): def __setattr__(self, key, value):
if key in self: if key in self:
raise NotImplementedError('EnumType values are immutable.') raise NotImplementedError('constant values are immutable.')
Type.__setattr__(self, key, value) Type.__setattr__(self, key, value)
def __setitem__(self, key, value): def __setitem__(self, key, value):
raise NotImplementedError('EnumType values are immutable.') raise NotImplementedError('constant values are immutable.')
def __delitem__(self, key): def __delitem__(self, key):
raise NotImplementedError('EnumType values are immutable.') raise NotImplementedError('constant values are immutable.')
def __hash__(self): def __hash__(self):
# All values are Python basic types, then easy to hash. # All values are Python basic types, then easy to hash.
return hash((type(self),) + tuple((k, self[k]) for k in sorted(self.keys()))) return hash((type(self), self.ctype) + tuple((k, self[k]) for k in sorted(self.keys())))
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
self.ctype == other.ctype and
len(self) == len(other) and len(self) == len(other) and
all(k in other for k in self) and all(k in other for k in self) and
all(self[k] == other[k] for k in self)) all(self[k] == other[k] for k in self))
...@@ -886,15 +912,13 @@ class EnumType(Type, dict): ...@@ -886,15 +912,13 @@ class EnumType(Type, dict):
# EnumType should be used to create constants available in both Python and C code. # EnumType should be used to create constants available in both Python and C code.
# However, for convenience, we make sure EnumType can have a value, like other common types, # However, for convenience, we make sure EnumType can have a value, like other common types,
# such that it could be used as-is as an op param. # such that it could be used as-is as an op param.
# As we currently allow enum constants to be booleans, integers or floating values, # C type of value is defined in self.ctype.
# we choose the biggest basic type (i.e. float) as type of enum values.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if strict: if not strict and isinstance(data, bool):
assert isinstance(data, float) data = int(data)
return data assert isinstance(data, (int, float))
assert isinstance(data, (bool, int, float)) return data
return float(data)
def values_eq(self, a, b): def values_eq(self, a, b):
return a == b return a == b
...@@ -902,13 +926,29 @@ class EnumType(Type, dict): ...@@ -902,13 +926,29 @@ class EnumType(Type, dict):
def values_eq_approx(self, a, b): def values_eq_approx(self, a, b):
return float(a) == float(b) return float(a) == float(b)
@staticmethod
def c_support_macro_code():
    """
    Return C preprocessor code shared by the enum types.

    When ``PY_MAJOR_VERSION >= 3``, it defines ``PyInt_Check`` and
    ``PyInt_AsLong`` as aliases of ``PyLong_Check`` and ``PyLong_AsLong``,
    unless those names are already defined.
    """
    return """
#if PY_MAJOR_VERSION >= 3
#ifndef PyInt_Check
#define PyInt_Check PyLong_Check
#endif
#ifndef PyInt_AsLong
#define PyInt_AsLong PyLong_AsLong
#endif
#endif
"""
def c_support_code(self): def c_support_code(self):
return ''.join(""" return (
#define %s %s self.c_support_macro_code() +
""" % (k, str(self[k])) for k in sorted(self.keys())) ''.join("""
#define %s %s
""" % (k, str(self[k])) for k in sorted(self.keys()))
)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
return """double %(name)s;""" % locals() return """%(ctype)s %(name)s;""" % dict(ctype=self.ctype, name=name)
def c_init(self, name, sub): def c_init(self, name, sub):
return "%(name)s = 0;" % locals() return "%(name)s = 0;" % locals()
...@@ -918,28 +958,90 @@ class EnumType(Type, dict): ...@@ -918,28 +958,90 @@ class EnumType(Type, dict):
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
return """ return """
%(name)s = PyFloat_AsDouble(py_%(name)s); if (PyInt_Check(py_%(name)s)) {
%(name)s = (%(ctype)s)PyInt_AsLong(py_%(name)s);
} else {
%(name)s = (%(ctype)s)PyFloat_AsDouble(py_%(name)s);
}
if (PyErr_Occurred()) { if (PyErr_Occurred()) {
%(name)s = 0; %(fail)s
} }
""" % locals() """ % dict(ctype=self.ctype, name=name, fail=sub['fail'])
class EnumList(EnumType): class EnumList(EnumType):
"""" """
Class that allows to create enumeration of constant integer values. Class that allows to create enumeration of constant values.
Same as :class:`EnumType`, but automatically gives an unique integer value to each constant in a list of Same as :class:`EnumType`, but automatically gives an unique integer value for each constant in a list of
constants names (constant at index i in the list will receive value i, i from ``0`` to ``len(constants)-1``). constants names (constant at index ``i`` in the list will receive value ``i``,
with ``i`` from ``0`` to ``len(constants) - 1``).
Example:: Example::
enum = EnumList(CONSTANT_1, CONSTANT_2, CONSTANT_3, CONSTANT_4, CONSTANT_5) enum = EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3', 'CONSTANT_4', 'CONSTANT_5')
print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4, enum.CONSTANT_5) print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4, enum.CONSTANT_5)
# will print: 0 1 2 3 4 # will print: 0 1 2 3 4
Like :class:`EnumType`, you can also define the C type for a variable able to handle these enum values.
Default C type is ``int``::
enum = EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3', 'CONSTANT_4', ctype='unsigned int')
""" """
def __init__(self, *args): def __init__(self, *args, **kwargs):
assert len(kwargs) == 0 or (len(kwargs) == 1 and 'ctype' in kwargs), \
type(self).__name__ + ': expected 0 or only 1 extra parameter "ctype".'
ctype = kwargs.pop('ctype', 'int')
if len(args) > len(set(args)): if len(args) > len(set(args)):
raise AttributeError('EnumList: some constants names are duplicated.') raise AttributeError(type(self).__name__ + ': some constants names are duplicated.')
super(EnumList, self).__init__(**{const_name: const_rank for (const_rank, const_name) in enumerate(args)})
kwargs = {const_name: const_rank for (const_rank, const_name) in enumerate(args)}
kwargs.update(ctype=ctype)
super(EnumList, self).__init__(**kwargs)
class CEnumType(EnumList):
    """
    Enumeration of constants that stand for constants already defined in C.

    - Constants should have the same names as in C.
    - In Python, constants have arbitrarily defined values.
      They should be used only to make a choice, not for their values.
    - In C code, the real values defined in C are used.
      They can be used either to make a choice or for their real values.

    Like :class:`EnumList`, you can also define the C type for a variable able to handle these enum values.
    Default C type is ``int``.

    .. code-block:: python

        enum = CEnumType('CONSTANT_CNAME_1', 'CONSTANT_CNAME_2', 'CONSTANT_CNAME_3', ctype='long')

    .. note::

        Be sure C constants are available in your C code. If they come from a C header, consider implementing
        ``c_headers()`` and ``c_header_dirs()`` in the Op class that uses CEnumType as its params type.

    """

    def c_support_code(self):
        # The constants themselves come from C, so only the shared
        # compatibility macros are emitted (no ``#define`` per constant).
        return self.c_support_macro_code()

    def c_extract(self, name, sub, check_input=True):
        # Map each Python-side integer value back to its C constant name.
        cname_by_value = {value: cname for (cname, value) in self.items()}
        case_template = """
        case %(i)d: %(name)s = %(constant_cname)s; break;
        """
        # One ``case`` per constant: the Python rank selects the C constant.
        cases = ''.join(
            case_template % dict(i=value, name=name, constant_cname=cname_by_value[value])
            for value in sorted(cname_by_value)
        )
        return """
        switch(PyInt_AsLong(py_%(name)s)) {
            %(cases)s
            default:
            {%(fail)s}
            break;
        }
        """ % dict(name=name, cases=cases, fail=sub['fail'])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论