提交 be97f3ee authored 作者: notoraptor's avatar notoraptor

Change struct naming, fix typos.

上级 8160671a
......@@ -28,6 +28,7 @@ See theano/common/tests/test_wrapper.py for a complete working example.
from __future__ import absolute_import, print_function, division
import re
import hashlib
from theano.gof.utils import MethodNotDefined
from theano.gof import Type
from theano.gof.cmodule import GCC_compiler as compiler
......@@ -101,7 +102,6 @@ class Wrapper(Type):
if len(kwargs) == 0:
raise ValueError('Cannot create Wrapper from empty data.')
type_names = []
for attribute_name in kwargs:
if re.match('^[A-Za-z_][A-Za-z0-9_]*$', attribute_name) is None:
raise SyntaxError('Wrapper: attribute "%s" should be a valid identifier.' % attribute_name)
......@@ -110,13 +110,11 @@ class Wrapper(Type):
if not isinstance(type_instance, Type):
raise TypeError('Wrapper: attribute "%s" should inherit from theano Type, got "%s".'
% (attribute_name, type_name))
type_names.append(type_name)
type_names.sort()
self.name = '_wrapper_struct_' + ('_'.join(type_names))
self.length = len(kwargs)
self.fields = tuple(sorted(kwargs.keys()))
self.types = tuple(kwargs[field] for field in self.fields)
self.name = self.generate_struct_name()
def __repr__(self):
return 'Wrapper<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)])
......@@ -128,6 +126,19 @@ class Wrapper(Type):
def __hash__(self):
return hash((type(self),) + self.fields + self.types)
def generate_struct_name(self):
""""
This method try to generate an unique name for the current instance.
This name is intended to be used as struct name in C code and
as constant definition to check if a similar Wrapper has already been created
(see c_support_code() below).
"""
fields_string = ','.join(self.fields)
types_string = ','.join(str(t) for t in self.types)
fields_hex = hashlib.md5(fields_string).hexdigest()
types_hex = hashlib.md5(types_string).hexdigest()
return '_wrapper_struct_%s_%s' % (fields_hex, types_hex)
def check_that_values_are_compatible(self, data, strict, allow_downcast):
wrap_instance = dict()
for i in range(self.length):
......@@ -148,7 +159,7 @@ class Wrapper(Type):
wrap_instance = dict()
for i in range(self.length):
if self.fields[i] not in data:
raise TypeError('%s expects a dictionary that has attribute "%s".' % (self, self.fields[i]))
raise TypeError('%s: filter expects a dictionary that has attribute "%s".' % (self, self.fields[i]))
try:
wrap_instance[self.fields[i]] = self.types[i].filter(data[self.fields[i]], strict, allow_downcast)
except Exception as e:
......@@ -257,7 +268,7 @@ class Wrapper(Type):
def c_support_code(self):
sub = {'fail': '{this->setErrorOccurred(); this->cleanup(); return;}'}
struct_name = self.name
struct_name_defined = struct_name.upper() + '_DEFINED'
struct_name_defined = struct_name.upper()
struct_fields = ''
struct_init = ''
struct_cleanup = ''
......@@ -268,18 +279,22 @@ class Wrapper(Type):
c_extract_list = []
for attribute_name, type_instance in zip(self.fields, self.types):
type_name = type_instance.__class__.__name__
try:
c_declare_list.append(type_instance.c_declare(attribute_name, sub))
except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement the method Type.c_declare().' % type_name)
raise RuntimeError('Wrapper: class "%s" should implement method Type.c_declare().' % type_name)
try:
c_init_list.append(type_instance.c_init(attribute_name, sub))
except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement the method Type.c_init().' % type_name)
raise RuntimeError('Wrapper: class "%s" should implement method Type.c_init().' % type_name)
try:
c_cleanup_list.append(type_instance.c_cleanup(attribute_name, sub))
except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement the method Type.c_cleanup().' % type_name)
raise RuntimeError('Wrapper: class "%s" should implement method Type.c_cleanup().' % type_name)
try:
c_extract_list.append("""
void extract_%(attribute_name)s(PyObject* py_%(attribute_name)s) {
......@@ -291,6 +306,7 @@ class Wrapper(Type):
})
except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement the method Type.c_extract().' % type_name)
struct_fields = '\n'.join(c_declare_list)
struct_init = '\n'.join(c_init_list)
struct_cleanup = '\n'.join(c_cleanup_list)
......@@ -339,6 +355,9 @@ class Wrapper(Type):
/* Extraction methods. */
%(struct_extraction_methods)s
/* Extract method. */
%(struct_extract_method)s
/* Other methods. */
void setErrorOccurred() {
++%(struct_name)s_error;
......@@ -346,13 +365,12 @@ class Wrapper(Type):
int errorOccurred() {
return %(struct_name)s_error;
}
%(struct_extract_method)s
};
#endif
""" % locals()
def c_code_cache_version(self):
return (1,)
return (1, 1)
def c_declare(self, name, sub, check_input=True):
struct_name = self.name
......@@ -373,17 +391,20 @@ class Wrapper(Type):
fail = sub['fail']
length = self.length
fields_list = '"%s"' % '", "'.join(self.fields)
check = 1 if check_input else 0
return """
const char* fields[] = {%(fields_list)s};
if (py_%(name)s == Py_None) {
PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None.");
%(fail)s
}
for (int i = 0; i < %(length)s; ++i) {
if (!PyObject_HasAttrString(py_%(name)s, fields[i])) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute %%s in object.", fields[i]);
if (%(check)s) {
if (py_%(name)s == Py_None) {
PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None.");
%(fail)s
}
for (int i = 0; i < %(length)s; ++i) {
if (!PyObject_HasAttrString(py_%(name)s, fields[i])) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute %%s in object.", fields[i]);
%(fail)s
}
}
}
for (int i = 0; i < %(length)s; ++i) {
PyObject* o = PyObject_GetAttrString(py_%(name)s, fields[i]);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论