提交 194290a8 authored 作者: notoraptor's avatar notoraptor

Extend enumeration types: add aliases.

Add public: fromalias(), has_alias(). Update documentation for enumeration types.
上级 00935985
...@@ -814,13 +814,20 @@ CDataType.Constant = CDataTypeConstant ...@@ -814,13 +814,20 @@ CDataType.Constant = CDataTypeConstant
class EnumType(Type, dict): class EnumType(Type, dict):
""" """
Main subclasses:
- :class:`EnumList`
- :class:`CEnumType`
Op parameter class that allows to create enumerations of constant values. Op parameter class that allows to create enumerations of constant values.
- Constants are available as object attributes in Python code and as macro-defined constants in C code. - Constants are available as object attributes in Python code and as macro-defined constants in C code.
- Constants can be floating values, integers, or booleans (automatically converted to integers). - Constants can be floating values, integers, or booleans (automatically converted to integers).
- Constants name must start with a capital letter and contain capital letters, underscores or digits. - Constants name must start with a capital letter and contain capital letters, underscores or digits.
- A constant can have an alias, and then be available through both constant name and constant alias.
Example:: **Example**
.. code-block:: python
enum = EnumType(CONSTANT_1=1, CONSTANT_2=2.5, CONSTANT_3=False, CONSTANT_4=True) enum = EnumType(CONSTANT_1=1, CONSTANT_2=2.5, CONSTANT_3=False, CONSTANT_4=True)
print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4) print (enum.CONSTANT_1, enum.CONSTANT_2, enum.CONSTANT_3, enum.CONSTANT_4)
...@@ -849,35 +856,106 @@ class EnumType(Type, dict): ...@@ -849,35 +856,106 @@ class EnumType(Type, dict):
size_t value = op_param_value; // contains enum.CONSTANT_1, i.e 0 size_t value = op_param_value; // contains enum.CONSTANT_1, i.e 0
**Example with aliases**
When creating an enum, you can give some aliases to specific constants while keeping other constants without aliases.
An alias must be a string, and there is currently no string format constraints.
To give an alias to a constant in the EnumType constructor, use the following key-value syntax::
constant_name=(constant_alias, constant_value)
You can then retrieve a constant from an alias with method ``EnumType.fromalias()``.
Aliases are intended to be used in Python code only (only constants names are available in C code).
Especially, an alias will be recognized by ``Enumtype.filter()`` method with non-strict filtering,
allowing a maximum flexibility for converting strings to numeric constants available in Python and C code.
.. code-block:: python
from theano.gof import EnumType
# You can remark that constant 'C' does not have an alias.
enum = EnumType(A=('alpha', 1), B=('beta', 2), C=3, D=('delta', 4))
# Constants are all directly available by name.
print(enum.A, enum.B, enum.C, enum.D)
# But we can also now get some constants by alias.
a = enum.fromalias('alpha')
b = enum.fromalias('beta')
d = enum.fromalias('delta')
# If method fromalias() receives an unknown alias,
# it will looks for a constant with this alias
# as exact constant name.
c = enum.fromalias('C') # will get enum.C
# An alias defined in an EnumType will be correctly converted with non-strict filtering.
value = enum.filter('delta', strict=False)
# value now contaisn enum.D, ie. 4.
.. note:: .. note::
This Type (and subclasses) is not complete and should never be used for regular graph operations. This Type (and subclasses) is not complete and should never be used for regular graph operations.
""" """
def check_ctype(self): def __init_ctype(self, ctype):
# C type may be a list of keywords, e.g. "unsigned long long". # C type may be a list of keywords, e.g. "unsigned long long".
# We should check each part. # We should check each part.
if not all(re.match('^[A-Za-z_][A-Za-z0-9_]*$', el) for el in self.ctype.split()): ctype_parts = ctype.split()
raise TypeError('%s: invalid C type' % type(self).__name__) if not all(re.match('^[A-Za-z_][A-Za-z0-9_]*$', el) for el in ctype_parts):
raise TypeError('%s: invalid C type.' % type(self).__name__)
self.ctype = ' '.join(ctype_parts)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.ctype = kwargs.pop('ctype', 'double') self.__init_ctype(kwargs.pop('ctype', 'double'))
self.check_ctype() self.aliases = dict()
for k in kwargs: for k in kwargs:
if re.match('^[A-Z][A-Z0-9_]*$', k) is None: if re.match('^[A-Z][A-Z0-9_]*$', k) is None:
raise AttributeError('%s: invalid enum name: "%s". ' raise AttributeError('%s: invalid enum name: "%s". '
'Only capital letters, underscores and digits ' 'Only capital letters, underscores and digits '
'are allowed.' % (type(self).__name__, k)) 'are allowed.' % (type(self).__name__, k))
if isinstance(kwargs[k], (list, tuple)):
if len(kwargs[k]) != 2:
raise TypeError('%s: when using a tuple to define a constant, your tuple should contain 2 values: '
'constant alias followed by constant value.' % type(self).__name__)
alias, value = kwargs[k]
if not isinstance(alias, str):
raise TypeError('%s: constant alias should be a string, got "%s".'
% (type(self).__name__, alias))
if alias in self.aliases:
raise TypeError('%s: consant alias "%s" already used.' % (type(self).__name__, alias))
self.aliases[alias] = k
kwargs[k] = value
if isinstance(kwargs[k], bool): if isinstance(kwargs[k], bool):
kwargs[k] = int(kwargs[k]) kwargs[k] = int(kwargs[k])
elif not isinstance(kwargs[k], (int, float)): elif not isinstance(kwargs[k], (int, float)):
raise ValueError('%s: constant "%s": expected integer or floating value, got "%s".' raise TypeError('%s: constant "%s": expected integer or floating value, got "%s".'
% (type(self).__name__, k, type(kwargs[k]).__name__)) % (type(self).__name__, k, type(kwargs[k]).__name__))
super(EnumType, self).__init__(**kwargs) super(EnumType, self).__init__(**kwargs)
def fromalias(self, alias):
"""
Get a constant value by its alias.
If there is not such alias in this enum, look for a constant
with this alias as constant name.
"""
return self[self.aliases[alias]] if alias in self.aliases else self[alias]
def has_alias(self, alias):
"""
return True if and only if this enum has this alias.
"""
return alias in self.aliases
def __repr__(self): def __repr__(self):
return '%s(%s)' % (type(self).__name__, ', '.join('%s:%s' % (k, self[k]) for k in sorted(self.keys()))) names_to_aliases = {constant_name: '' for constant_name in self}
for alias in self.aliases:
names_to_aliases[self.aliases[alias]] = '(%s)' % alias
return '%s<%s>(%s)' % (type(self).__name__, self.ctype,
', '.join('%s%s:%s' % (k, names_to_aliases[k], self[k]) for k in sorted(self.keys())))
def __getattr__(self, key): def __getattr__(self, key):
if key in self: if key in self:
...@@ -897,14 +975,19 @@ class EnumType(Type, dict): ...@@ -897,14 +975,19 @@ class EnumType(Type, dict):
def __hash__(self): def __hash__(self):
# All values are Python basic types, then easy to hash. # All values are Python basic types, then easy to hash.
return hash((type(self), self.ctype) + tuple((k, self[k]) for k in sorted(self.keys()))) return hash((type(self), self.ctype) +
tuple((k, self[k]) for k in sorted(self.keys())) +
tuple((a, self.aliases[a]) for a in sorted(self.aliases.keys())))
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and return (type(self) == type(other) and
self.ctype == other.ctype and self.ctype == other.ctype and
len(self) == len(other) and len(self) == len(other) and
len(self.aliases) == len(other.aliases) and
all(k in other for k in self) and all(k in other for k in self) and
all(self[k] == other[k] for k in self)) all(a in other.aliases for a in self.aliases) and
all(self[k] == other[k] for k in self) and
all(self.aliases[a] == other.aliases[a] for a in self.aliases))
# EnumType should be used to create constants available in both Python and C code. # EnumType should be used to create constants available in both Python and C code.
# However, for convenience, we make sure EnumType can have a value, like other common types, # However, for convenience, we make sure EnumType can have a value, like other common types,
...@@ -912,8 +995,13 @@ class EnumType(Type, dict): ...@@ -912,8 +995,13 @@ class EnumType(Type, dict):
# C type of value is defined in self.ctype. # C type of value is defined in self.ctype.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if not strict and isinstance(data, bool): if not strict:
data = int(data) if isinstance(data, bool):
data = int(data)
elif isinstance(data, str):
# We now accept strings as data values.
# Strings should be a constant alias or a constant name.
data = self.fromalias(data)
assert data in self.values() assert data in self.values()
return data return data
...@@ -947,7 +1035,7 @@ class EnumType(Type, dict): ...@@ -947,7 +1035,7 @@ class EnumType(Type, dict):
return """%(ctype)s %(name)s;""" % dict(ctype=self.ctype, name=name) return """%(ctype)s %(name)s;""" % dict(ctype=self.ctype, name=name)
def c_init(self, name, sub): def c_init(self, name, sub):
return "%(name)s = 0;" % dict(name=name) return "%(name)s = (%(ctype)s)0;" % dict(name=name, ctype=self.ctype)
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return "" return ""
...@@ -965,11 +1053,14 @@ class EnumType(Type, dict): ...@@ -965,11 +1053,14 @@ class EnumType(Type, dict):
""" % dict(ctype=self.ctype, name=name, fail=sub['fail']) """ % dict(ctype=self.ctype, name=name, fail=sub['fail'])
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1, 1)
class EnumList(EnumType): class EnumList(EnumType):
""" """
**Inherit from**:
- :class:`EnumType`
Op parameter class that allows to create enumeration of constant values. Op parameter class that allows to create enumeration of constant values.
Same as :class:`EnumType`, but automatically gives an unique integer value for each constant in a list of Same as :class:`EnumType`, but automatically gives an unique integer value for each constant in a list of
constants names (constant at index ``i`` in the list will receive value ``i``, constants names (constant at index ``i`` in the list will receive value ``i``,
...@@ -986,6 +1077,14 @@ class EnumList(EnumType): ...@@ -986,6 +1077,14 @@ class EnumList(EnumType):
enum = EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3', 'CONSTANT_4', ctype='unsigned int') enum = EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3', 'CONSTANT_4', ctype='unsigned int')
Like :class:`EnumType`, you can also add an alias to a constant, by replacing the only constant name
(e.g. ``'CONSTANT_NAME'``) by a couple with constant name first and constant alias second
(e.g. ``('CONSTANT_NAME', 'constant_alias')``).
.. code-block:: python
enum = EnumList(('A', 'alpha'), ('B', 'beta'), 'C', 'D', 'E', 'F', ('G', 'gamma'))
See test class :class:`theano.gof.tests.test_types.TestOpEnumList` for a working example. See test class :class:`theano.gof.tests.test_types.TestOpEnumList` for a working example.
""" """
...@@ -995,16 +1094,35 @@ class EnumList(EnumType): ...@@ -995,16 +1094,35 @@ class EnumList(EnumType):
type(self).__name__ + ': expected 0 or only 1 extra parameter "ctype".' type(self).__name__ + ': expected 0 or only 1 extra parameter "ctype".'
ctype = kwargs.pop('ctype', 'int') ctype = kwargs.pop('ctype', 'int')
if len(args) > len(set(args)): for arg_rank, arg in enumerate(args):
raise AttributeError(type(self).__name__ + ': some constants names are duplicated.') if isinstance(arg, (list, tuple)):
if len(arg) != 2:
raise TypeError('%s: when using a tuple to define a constant, your tuple should contain 2 values: '
'constant name followed by constant alias.' % type(self).__name__)
constant_name, constant_alias = arg
if not isinstance(constant_alias, str):
raise TypeError('%s: constant alias should be a string, got "%s".'
% (type(self).__name__, constant_alias))
constant_value = (constant_alias, arg_rank)
else:
constant_name = arg
constant_value = arg_rank
if not isinstance(constant_name, str):
raise TypeError('%s: constant name should be a string, got "%s".'
% (type(self).__name__, constant_name))
if constant_name in kwargs:
raise TypeError('%s: constant name already used ("%s").' % (type(self).__name__, constant_name))
kwargs[constant_name] = constant_value
kwargs = {const_name: const_rank for (const_rank, const_name) in enumerate(args)}
kwargs.update(ctype=ctype) kwargs.update(ctype=ctype)
super(EnumList, self).__init__(**kwargs) super(EnumList, self).__init__(**kwargs)
class CEnumType(EnumList): class CEnumType(EnumList):
""" """
**Inherit from**:
- :class:`EnumList`
Op parameter class that allows to create enumeration of constant values that represent C-defined constants. Op parameter class that allows to create enumeration of constant values that represent C-defined constants.
- Constant should have same names as in C. - Constant should have same names as in C.
...@@ -1020,6 +1138,8 @@ class CEnumType(EnumList): ...@@ -1020,6 +1138,8 @@ class CEnumType(EnumList):
enum = CEnumType('CONSTANT_CNAME_1', 'CONSTANT_CNAME_2', 'CONSTANT_CNAME_3', ctype='long') enum = CEnumType('CONSTANT_CNAME_1', 'CONSTANT_CNAME_2', 'CONSTANT_CNAME_3', ctype='long')
Like :class:`EnumList`, you can also add an alias to a constant, with same syntax as in :class:`EnumList`.
See test class :class:`theano.gof.tests.test_types.TestOpCEnumType` for a working example. See test class :class:`theano.gof.tests.test_types.TestOpCEnumType` for a working example.
.. note:: .. note::
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论