提交 8ef62afb authored 作者: notoraptor's avatar notoraptor

Make CEnumType complete so that it can be used in Theano graphs.

上级 c675144a
...@@ -909,7 +909,11 @@ class EnumType(Type, dict): ...@@ -909,7 +909,11 @@ class EnumType(Type, dict):
.. note:: .. note::
This Type (and subclasses) is not complete and should never be used for regular graph operations. :class:`EnumType` is not complete and should never be used for regular graph operations.
:class:`EnumList` is not complete and should never be used for regular graph operations.
**:class:`CEnumType` is complete.**
""" """
...@@ -975,9 +979,9 @@ class EnumType(Type, dict): ...@@ -975,9 +979,9 @@ class EnumType(Type, dict):
def get_aliases(self): def get_aliases(self):
""" """
Return the list of all aliases in this enumeration. Return the sorted tuple of all aliases in this enumeration.
""" """
return self.aliases.keys() return tuple(sorted(self.aliases.keys()))
def __repr__(self): def __repr__(self):
names_to_aliases = {constant_name: '' for constant_name in self} names_to_aliases = {constant_name: '' for constant_name in self}
...@@ -1049,6 +1053,9 @@ class EnumType(Type, dict): ...@@ -1049,6 +1053,9 @@ class EnumType(Type, dict):
#ifndef PyInt_AsLong #ifndef PyInt_AsLong
#define PyInt_AsLong PyLong_AsLong #define PyInt_AsLong PyLong_AsLong
#endif #endif
#ifndef PyInt_FromLong
#define PyInt_FromLong PyLong_FromLong
#endif
#endif #endif
""" """
...@@ -1241,5 +1248,22 @@ class CEnumType(EnumList): ...@@ -1241,5 +1248,22 @@ class CEnumType(EnumList):
""" % dict(i=i, name=name, constant_cname=swapped_dict[i]) for i in sorted(swapped_dict.keys())), """ % dict(i=i, name=name, constant_cname=swapped_dict[i]) for i in sorted(swapped_dict.keys())),
fail=sub['fail']) fail=sub['fail'])
def c_sync(self, name, sub):
return """
int py_value = -1;
Py_XDECREF(py_%(name)s);
/* We assume that ctype is an integer type usable in a switch. */
switch (%(name)s) {
%(cases)s
default:
PyErr_SetString(PyExc_ValueError, "CEnumType: cannot map C value to Python constant.");
{%(fail)s}
break;
}
py_%(name)s = PyInt_FromLong(py_value);
""" % dict(name=name, fail=sub['fail'], cases=''.join("""
case %(constant_cname)s: py_value = %(constant_pyvalue)d; break;
""" % dict(constant_cname=k, constant_pyvalue=v) for k, v in sorted(self.items(), key=lambda t: t[1])))
def c_code_cache_version(self): def c_code_cache_version(self):
return (1, super(CEnumType, self).c_code_cache_version()) return (1, super(CEnumType, self).c_code_cache_version())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论