提交 00935985 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5710 from notoraptor/create-c-constants-in-cops

Implement new op param type `EnumType`
......@@ -55,6 +55,8 @@ from theano.gof.link import \
from theano.gof.op import \
Op, OpenMPOp, PureOp, COp, ops_with_inner_function
from theano.gof.type import EnumType, EnumList, CEnumType
from theano.gof.opt import (
Optimizer,
optimizer, inplace_optimizer,
......
......@@ -2,9 +2,10 @@ from __future__ import absolute_import, print_function, division
import numpy as np
import theano
from theano import Op, Apply
from theano import Op, Apply, scalar
from theano.tensor import TensorType
from theano.gof.type import CDataType
from theano.gof.type import CDataType, EnumType, EnumList, CEnumType
from unittest import TestCase
from nose.plugins.skip import SkipTest
......@@ -76,3 +77,165 @@ def test_cdata():
v2 = f(v)
assert (v2 == v).all()
class MyOpEnumList(Op):
__props__ = ('op_chosen',)
params_type = EnumList('ADD', 'SUB', 'MULTIPLY', 'DIVIDE', ctype='unsigned long long')
def __init__(self, choose_op):
assert self.params_type.ADD == 0
assert self.params_type.SUB == 1
assert self.params_type.MULTIPLY == 2
assert self.params_type.DIVIDE == 3
op_to_const = {'+': self.params_type.ADD,
'-': self.params_type.SUB,
'*': self.params_type.MULTIPLY,
'/': self.params_type.DIVIDE}
self.op_chosen = op_to_const[choose_op]
def get_params(self, node):
return self.op_chosen
def make_node(self, a, b):
return Apply(self, [scalar.as_scalar(a), scalar.as_scalar(b)], [scalar.float64()])
def perform(self, node, inputs, outputs, op):
a, b = inputs
o, = outputs
if op == self.params_type.ADD:
o[0] = a + b
elif op == self.params_type.SUB:
o[0] = a - b
elif op == self.params_type.MULTIPLY:
o[0] = a * b
elif op == self.params_type.DIVIDE:
if any(dtype in theano.tensor.continuous_dtypes for dtype in (a.dtype, b.dtype)):
o[0] = a / b
else:
o[0] = a // b
else:
raise NotImplementedError('Unknown op id ' + str(op))
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
return """
switch(%(op)s) {
case ADD:
%(o)s = %(a)s + %(b)s;
break;
case SUB:
%(o)s = %(a)s - %(b)s;
break;
case MULTIPLY:
%(o)s = %(a)s * %(b)s;
break;
case DIVIDE:
%(o)s = %(a)s / %(b)s;
break;
default:
{%(fail)s}
break;
}
""" % dict(op=sub['params'], o=outputs[0], a=inputs[0], b=inputs[1], fail=sub['fail'])
class MyOpCEnumType(Op):
__props__ = ('ctype_index',)
params_type = CEnumType('SIZE_INT', 'SIZE_FLOAT', 'SIZE_LONG_LONG', ctype='size_t')
# Just for testing, we define our own macros.
def c_support_code(self):
return """
#define SIZE_INT sizeof(int)
#define SIZE_FLOAT sizeof(float)
#define SIZE_LONG_LONG sizeof(long long)
"""
def __init__(self, ctype):
# As we see, Python values of constants are not related to real C values
# (sizeof(int) will never be 0).
assert self.params_type.SIZE_INT == 0
assert self.params_type.SIZE_FLOAT == 1
assert self.params_type.SIZE_LONG_LONG == 2
d = {'int': self.params_type.SIZE_INT,
'float': self.params_type.SIZE_FLOAT,
'long long': self.params_type.SIZE_LONG_LONG}
self.ctype_index = d[ctype]
def get_params(self, node):
return self.ctype_index
def make_node(self):
return Apply(self, [], [scalar.uint32()])
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
return """
%(o)s = %(sizeof_ctype)s;
""" % dict(o=outputs[0],
# params in C code will already contains expected C constant value.
sizeof_ctype=sub['params'])
class TestEnumTypes(TestCase):
def test_enum_class(self):
# Check that invalid enum name raises exception.
for invalid_name in ('a', '_A', '0'):
try:
EnumList(invalid_name)
except AttributeError:
pass
else:
raise Exception('EnumList with invalid name should faild.')
try:
EnumType(**{invalid_name: 0})
except AttributeError:
pass
else:
raise Exception('EnumType with invalid name should fail.')
# Check that invalid enum value raises exception.
try:
EnumType(INVALID_VALUE='string is not allowed.')
except ValueError:
pass
else:
raise Exception('EnumType with invalid value should fail.')
# Check EnumType.
e1 = EnumType(C1=True, C2=12, C3=True, C4=-1, C5=False, C6=0.0)
e2 = EnumType(C1=1, C2=12, C3=1, C4=-1.0, C5=0.0, C6=0)
assert e1 == e2
assert not (e1 != e2)
assert hash(e1) == hash(e2)
# Check access to attributes.
assert len((e1.ctype, e1.C1, e1.C2, e1.C3, e1.C4, e1.C5, e1.C6)) == 7
def test_op_with_enumlist(self):
a = scalar.int32()
b = scalar.int32()
c_add = MyOpEnumList('+')(a, b)
c_sub = MyOpEnumList('-')(a, b)
c_multiply = MyOpEnumList('*')(a, b)
c_divide = MyOpEnumList('/')(a, b)
f = theano.function([a, b], [c_add, c_sub, c_multiply, c_divide])
va = 12
vb = 15
ref = [va + vb, va - vb, va * vb, va // vb]
out = f(va, vb)
assert ref == out, (ref, out)
def test_op_with_cenumtype(self):
sizeof_int = MyOpCEnumType('int')()
sizeof_float = MyOpCEnumType('float')()
sizeof_long_long = MyOpCEnumType('long long')()
f = theano.function([], [sizeof_int, sizeof_float, sizeof_long_long])
out = f()
print('(sizeof(int): ', out[0], ', sizeof(float): ', out[1], ', sizeof(long long): ', out[2], ') ', sep='')
......@@ -10,6 +10,7 @@ import ctypes
from six import string_types
import re
import theano
from theano.gof import utils
from theano.gof.utils import MethodNotDefined, object2
......@@ -809,3 +810,245 @@ class CDataTypeConstant(graph.Constant):
return (self.type,)
CDataType.Constant = CDataTypeConstant
class EnumType(Type, dict):
"""
Op parameter class that allows to create enumerations of constant values.
- Constants are available as object attributes in Python code and as macro-defined constants in C code.
- Constants can be floating values, integers, or booleans (automatically converted to integers).
- Constants name must start with a capital letter and contain capital letters, underscores or digits.
Example::
enum = EnumType(CONSTANT_1=1, CONSTANT_2=2.5, CONSTANT_3=False, CONSTANT_4=True)
print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4)
# will print 1 2.5 0 1
In C code:
.. code-block:: c
int constant_1 = CONSTANT_1;
double constant_2 = CONSTANT_2;
int constant_3 = CONSTANT_3; // constant_3 == 0
int constant_4 = CONSTANT_4; // constant_4 == 1
You can also specify a C type for the op param if you want to pass one of these constant values at runtime.
Default C type is ``double``.
.. code-block:: python
enum = EnumType(CONSTANT_1=0, CONSTANT_2=1, CONSTANT_3=2, ctype='size_t')
op_param_value = enum.CONSTANT_1
In C code:
.. code-block:: c
size_t value = op_param_value; // contains enum.CONSTANT_1, i.e 0
.. note::
This Type (and subclasses) is not complete and should never be used for regular graph operations.
"""
def check_ctype(self):
# C type may be a list of keywords, e.g. "unsigned long long".
# We should check each part.
if not all(re.match('^[A-Za-z_][A-Za-z0-9_]*$', el) for el in self.ctype.split()):
raise TypeError('%s: invalid C type' % type(self).__name__)
def __init__(self, **kwargs):
self.ctype = kwargs.pop('ctype', 'double')
self.check_ctype()
for k in kwargs:
if re.match('^[A-Z][A-Z0-9_]*$', k) is None:
raise AttributeError('%s: invalid enum name: "%s". '
'Only capital letters, underscores and digits '
'are allowed.' % (type(self).__name__, k))
if isinstance(kwargs[k], bool):
kwargs[k] = int(kwargs[k])
elif not isinstance(kwargs[k], (int, float)):
raise ValueError('%s: constant "%s": expected integer or floating value, got "%s".'
% (type(self).__name__, k, type(kwargs[k]).__name__))
super(EnumType, self).__init__(**kwargs)
def __repr__(self):
return '%s(%s)' % (type(self).__name__, ', '.join('%s:%s' % (k, self[k]) for k in sorted(self.keys())))
def __getattr__(self, key):
if key in self:
return self[key]
return Type.__getattr__(self, key)
def __setattr__(self, key, value):
if key in self:
raise NotImplementedError('constant values are immutable.')
Type.__setattr__(self, key, value)
def __setitem__(self, key, value):
raise NotImplementedError('constant values are immutable.')
def __delitem__(self, key):
raise NotImplementedError('constant values are immutable.')
def __hash__(self):
# All values are Python basic types, then easy to hash.
return hash((type(self), self.ctype) + tuple((k, self[k]) for k in sorted(self.keys())))
def __eq__(self, other):
return (type(self) == type(other) and
self.ctype == other.ctype and
len(self) == len(other) and
all(k in other for k in self) and
all(self[k] == other[k] for k in self))
# EnumType should be used to create constants available in both Python and C code.
# However, for convenience, we make sure EnumType can have a value, like other common types,
# such that it could be used as-is as an op param.
# C type of value is defined in self.ctype.
def filter(self, data, strict=False, allow_downcast=None):
if not strict and isinstance(data, bool):
data = int(data)
assert data in self.values()
return data
def values_eq(self, a, b):
return a == b
def values_eq_approx(self, a, b):
# For an enum, it does not have a meaning to be approx equal.
return self.values_eq(a, b)
pyint_compat_code = """
#if PY_MAJOR_VERSION >= 3
#ifndef PyInt_Check
#define PyInt_Check PyLong_Check
#endif
#ifndef PyInt_AsLong
#define PyInt_AsLong PyLong_AsLong
#endif
#endif
"""
def c_support_code(self):
return (
self.pyint_compat_code +
''.join("""
#define %s %s
""" % (k, str(self[k])) for k in sorted(self.keys()))
)
def c_declare(self, name, sub, check_input=True):
return """%(ctype)s %(name)s;""" % dict(ctype=self.ctype, name=name)
def c_init(self, name, sub):
return "%(name)s = 0;" % dict(name=name)
def c_cleanup(self, name, sub):
return ""
def c_extract(self, name, sub, check_input=True):
return """
if (PyInt_Check(py_%(name)s)) {
%(name)s = (%(ctype)s)PyInt_AsLong(py_%(name)s);
} else {
%(name)s = (%(ctype)s)PyFloat_AsDouble(py_%(name)s);
}
if (PyErr_Occurred()) {
%(fail)s
}
""" % dict(ctype=self.ctype, name=name, fail=sub['fail'])
def c_code_cache_version(self):
return (1,)
class EnumList(EnumType):
"""
Op parameter class that allows to create enumeration of constant values.
Same as :class:`EnumType`, but automatically gives an unique integer value for each constant in a list of
constants names (constant at index ``i`` in the list will receive value ``i``,
with ``i`` from ``0`` to ``len(constants) - 1``).
Example::
enum = EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3', 'CONSTANT_4', 'CONSTANT_5')
print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4, enum.CONSTANT_5)
# will print: 0 1 2 3 4
Like :class:`EnumType`, you can also define the C type for the op param.
Default C type is ``int``::
enum = EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3', 'CONSTANT_4', ctype='unsigned int')
See test class :class:`theano.gof.tests.test_types.TestOpEnumList` for a working example.
"""
def __init__(self, *args, **kwargs):
assert len(kwargs) == 0 or (len(kwargs) == 1 and 'ctype' in kwargs), \
type(self).__name__ + ': expected 0 or only 1 extra parameter "ctype".'
ctype = kwargs.pop('ctype', 'int')
if len(args) > len(set(args)):
raise AttributeError(type(self).__name__ + ': some constants names are duplicated.')
kwargs = {const_name: const_rank for (const_rank, const_name) in enumerate(args)}
kwargs.update(ctype=ctype)
super(EnumList, self).__init__(**kwargs)
class CEnumType(EnumList):
"""
Op parameter class that allows to create enumeration of constant values that represent C-defined constants.
- Constant should have same names as in C.
- In Python, constants will have arbitrary-defined values.
They should be used only for choices, not for its values.
- In C code, the real values defined in C will be used.
They could be used either for choices or for its real values.
Like :class:`EnumList`, you can also define the C type for the op param.
Default C type is ``int``.
.. code-block:: python
enum = CEnumType('CONSTANT_CNAME_1', 'CONSTANT_CNAME_2', 'CONSTANT_CNAME_3', ctype='long')
See test class :class:`theano.gof.tests.test_types.TestOpCEnumType` for a working example.
.. note::
Be sure C constants are available in your C code. If they come from a C header, consider implementing
``c_headers()`` and ``c_header_dirs()`` in the Op class where you use CEnumType as op parameter type.
"""
def c_support_code(self):
return self.pyint_compat_code
def c_extract(self, name, sub, check_input=True):
swapped_dict = dict((v, k) for (k, v) in self.items())
# swapped_dict's keys are integers.
return """
switch(PyInt_AsLong(py_%(name)s)) {
%(cases)s
default:
PyErr_SetString(PyExc_ValueError, "CEnumType: invalid value to map to C constants.");
{%(fail)s}
break;
}
""" % dict(name=name,
cases=''.join("""
case %(i)d: %(name)s = %(constant_cname)s; break;
""" % dict(i=i, name=name, constant_cname=swapped_dict[i]) for i in sorted(swapped_dict.keys())),
fail=sub['fail'])
def c_code_cache_version(self):
return (1, super(CEnumType, self).c_code_cache_version())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论