提交 be97f3ee authored 作者: notoraptor's avatar notoraptor

Change struct naming, fix typos.

上级 8160671a
...@@ -28,6 +28,7 @@ See theano/common/tests/test_wrapper.py for a complete working example. ...@@ -28,6 +28,7 @@ See theano/common/tests/test_wrapper.py for a complete working example.
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import re import re
import hashlib
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.gof import Type from theano.gof import Type
from theano.gof.cmodule import GCC_compiler as compiler from theano.gof.cmodule import GCC_compiler as compiler
...@@ -101,7 +102,6 @@ class Wrapper(Type): ...@@ -101,7 +102,6 @@ class Wrapper(Type):
if len(kwargs) == 0: if len(kwargs) == 0:
raise ValueError('Cannot create Wrapper from empty data.') raise ValueError('Cannot create Wrapper from empty data.')
type_names = []
for attribute_name in kwargs: for attribute_name in kwargs:
if re.match('^[A-Za-z_][A-Za-z0-9_]*$', attribute_name) is None: if re.match('^[A-Za-z_][A-Za-z0-9_]*$', attribute_name) is None:
raise SyntaxError('Wrapper: attribute "%s" should be a valid identifier.' % attribute_name) raise SyntaxError('Wrapper: attribute "%s" should be a valid identifier.' % attribute_name)
...@@ -110,13 +110,11 @@ class Wrapper(Type): ...@@ -110,13 +110,11 @@ class Wrapper(Type):
if not isinstance(type_instance, Type): if not isinstance(type_instance, Type):
raise TypeError('Wrapper: attribute "%s" should inherit from theano Type, got "%s".' raise TypeError('Wrapper: attribute "%s" should inherit from theano Type, got "%s".'
% (attribute_name, type_name)) % (attribute_name, type_name))
type_names.append(type_name)
type_names.sort()
self.name = '_wrapper_struct_' + ('_'.join(type_names))
self.length = len(kwargs) self.length = len(kwargs)
self.fields = tuple(sorted(kwargs.keys())) self.fields = tuple(sorted(kwargs.keys()))
self.types = tuple(kwargs[field] for field in self.fields) self.types = tuple(kwargs[field] for field in self.fields)
self.name = self.generate_struct_name()
def __repr__(self): def __repr__(self):
return 'Wrapper<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)]) return 'Wrapper<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)])
...@@ -128,6 +126,19 @@ class Wrapper(Type): ...@@ -128,6 +126,19 @@ class Wrapper(Type):
def __hash__(self): def __hash__(self):
return hash((type(self),) + self.fields + self.types) return hash((type(self),) + self.fields + self.types)
def generate_struct_name(self):
""""
This method try to generate an unique name for the current instance.
This name is intended to be used as struct name in C code and
as constant definition to check if a similar Wrapper has already been created
(see c_support_code() below).
"""
fields_string = ','.join(self.fields)
types_string = ','.join(str(t) for t in self.types)
fields_hex = hashlib.md5(fields_string).hexdigest()
types_hex = hashlib.md5(types_string).hexdigest()
return '_wrapper_struct_%s_%s' % (fields_hex, types_hex)
def check_that_values_are_compatible(self, data, strict, allow_downcast): def check_that_values_are_compatible(self, data, strict, allow_downcast):
wrap_instance = dict() wrap_instance = dict()
for i in range(self.length): for i in range(self.length):
...@@ -148,7 +159,7 @@ class Wrapper(Type): ...@@ -148,7 +159,7 @@ class Wrapper(Type):
wrap_instance = dict() wrap_instance = dict()
for i in range(self.length): for i in range(self.length):
if self.fields[i] not in data: if self.fields[i] not in data:
raise TypeError('%s expects a dictionary that has attribute "%s".' % (self, self.fields[i])) raise TypeError('%s: filter expects a dictionary that has attribute "%s".' % (self, self.fields[i]))
try: try:
wrap_instance[self.fields[i]] = self.types[i].filter(data[self.fields[i]], strict, allow_downcast) wrap_instance[self.fields[i]] = self.types[i].filter(data[self.fields[i]], strict, allow_downcast)
except Exception as e: except Exception as e:
...@@ -257,7 +268,7 @@ class Wrapper(Type): ...@@ -257,7 +268,7 @@ class Wrapper(Type):
def c_support_code(self): def c_support_code(self):
sub = {'fail': '{this->setErrorOccurred(); this->cleanup(); return;}'} sub = {'fail': '{this->setErrorOccurred(); this->cleanup(); return;}'}
struct_name = self.name struct_name = self.name
struct_name_defined = struct_name.upper() + '_DEFINED' struct_name_defined = struct_name.upper()
struct_fields = '' struct_fields = ''
struct_init = '' struct_init = ''
struct_cleanup = '' struct_cleanup = ''
...@@ -268,18 +279,22 @@ class Wrapper(Type): ...@@ -268,18 +279,22 @@ class Wrapper(Type):
c_extract_list = [] c_extract_list = []
for attribute_name, type_instance in zip(self.fields, self.types): for attribute_name, type_instance in zip(self.fields, self.types):
type_name = type_instance.__class__.__name__ type_name = type_instance.__class__.__name__
try: try:
c_declare_list.append(type_instance.c_declare(attribute_name, sub)) c_declare_list.append(type_instance.c_declare(attribute_name, sub))
except MethodNotDefined: except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement the method Type.c_declare().' % type_name) raise RuntimeError('Wrapper: class "%s" should implement method Type.c_declare().' % type_name)
try: try:
c_init_list.append(type_instance.c_init(attribute_name, sub)) c_init_list.append(type_instance.c_init(attribute_name, sub))
except MethodNotDefined: except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement the method Type.c_init().' % type_name) raise RuntimeError('Wrapper: class "%s" should implement method Type.c_init().' % type_name)
try: try:
c_cleanup_list.append(type_instance.c_cleanup(attribute_name, sub)) c_cleanup_list.append(type_instance.c_cleanup(attribute_name, sub))
except MethodNotDefined: except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement the method Type.c_cleanup().' % type_name) raise RuntimeError('Wrapper: class "%s" should implement method Type.c_cleanup().' % type_name)
try: try:
c_extract_list.append(""" c_extract_list.append("""
void extract_%(attribute_name)s(PyObject* py_%(attribute_name)s) { void extract_%(attribute_name)s(PyObject* py_%(attribute_name)s) {
...@@ -291,6 +306,7 @@ class Wrapper(Type): ...@@ -291,6 +306,7 @@ class Wrapper(Type):
}) })
except MethodNotDefined: except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement the method Type.c_extract().' % type_name) raise RuntimeError('Wrapper: class "%s" should implement the method Type.c_extract().' % type_name)
struct_fields = '\n'.join(c_declare_list) struct_fields = '\n'.join(c_declare_list)
struct_init = '\n'.join(c_init_list) struct_init = '\n'.join(c_init_list)
struct_cleanup = '\n'.join(c_cleanup_list) struct_cleanup = '\n'.join(c_cleanup_list)
...@@ -339,6 +355,9 @@ class Wrapper(Type): ...@@ -339,6 +355,9 @@ class Wrapper(Type):
/* Extraction methods. */ /* Extraction methods. */
%(struct_extraction_methods)s %(struct_extraction_methods)s
/* Extract method. */
%(struct_extract_method)s
/* Other methods. */ /* Other methods. */
void setErrorOccurred() { void setErrorOccurred() {
++%(struct_name)s_error; ++%(struct_name)s_error;
...@@ -346,13 +365,12 @@ class Wrapper(Type): ...@@ -346,13 +365,12 @@ class Wrapper(Type):
int errorOccurred() { int errorOccurred() {
return %(struct_name)s_error; return %(struct_name)s_error;
} }
%(struct_extract_method)s
}; };
#endif #endif
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1, 1)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
struct_name = self.name struct_name = self.name
...@@ -373,17 +391,20 @@ class Wrapper(Type): ...@@ -373,17 +391,20 @@ class Wrapper(Type):
fail = sub['fail'] fail = sub['fail']
length = self.length length = self.length
fields_list = '"%s"' % '", "'.join(self.fields) fields_list = '"%s"' % '", "'.join(self.fields)
check = 1 if check_input else 0
return """ return """
const char* fields[] = {%(fields_list)s}; const char* fields[] = {%(fields_list)s};
if (py_%(name)s == Py_None) { if (%(check)s) {
PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None."); if (py_%(name)s == Py_None) {
%(fail)s PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None.");
}
for (int i = 0; i < %(length)s; ++i) {
if (!PyObject_HasAttrString(py_%(name)s, fields[i])) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute %%s in object.", fields[i]);
%(fail)s %(fail)s
} }
for (int i = 0; i < %(length)s; ++i) {
if (!PyObject_HasAttrString(py_%(name)s, fields[i])) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute %%s in object.", fields[i]);
%(fail)s
}
}
} }
for (int i = 0; i < %(length)s; ++i) { for (int i = 0; i < %(length)s; ++i) {
PyObject* o = PyObject_GetAttrString(py_%(name)s, fields[i]); PyObject* o = PyObject_GetAttrString(py_%(name)s, fields[i]);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论