提交 feedbb19 authored 作者: notoraptor's avatar notoraptor

Simplify and secure code.

Let exceptions raise freely (for better debugging).
上级 72eeee09
......@@ -797,7 +797,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
"""
# We add a default get_params() implementation which will try to detect params from the op
# if params_type is set to a Wrapper. If not, we raise a MethodNodDefined exception.
# if params_type is set to a Wrapper. If not, we raise a MethodNotDefined exception.
def get_params(self, node):
if hasattr(self, 'params_type'):
# If params_type is a Wrapper, we try to extract params from the op.
......
......@@ -6,23 +6,38 @@ This module contains two classes:
- Wrap: convenient class to create an object that is compatible with the param op type.
Example of usage:
# Importation
>>> from theano.common import Wrapper, Wrap
# In a op you create:
>>> params_type = Wrapper(attr1=TensorType('int32', (False, False)), attr2=TensorType('float64', (True,False)))
# In the get_params() method of your op:
>>> return Wrap(attr1=numpyArray1, attr2=numpyArray2)
# In perform() implementation (with params named `param`):
>>> print(param.attr1)
>>> print(param.attr2)
# In c_code() implementation (with `param = sub['params']`):
```
Importation:
from theano.gof.wrapper import Wrapper
In an op you create:
params_type = Wrapper(attr1=TensorType('int32', (False, False)), attr2=TensorType('float64', (True, False)))
If your op contains props `attr1` AND `attr2`, the op.get_params() method will
automatically try to look for it and generate an appropriate wrapped struct.
The props must be able to pass the filtering (not strict, downcasting allowed)
of corresponding types defined into Wrapper.
__props__ = ('attr1', 'attr2')
def __init__(value_attr1, value_attr2):
self.attr1 = value_attr1
self.attr2 = value_attr2
In perform() implementation (with params named `param`):
var1 = param.attr1
var2 = param.attr2
In c_code() implementation (with `param = sub['params']`):
PyArrayObject* attr1 = param.attr1;
PyArrayObject* attr2 = param.attr2;
/* Just use attr1 and attr2, you won't need to free them or whatever else. */
```
See theano/common/tests/test_wrapper.py for a complete working example.
See theano/gof/tests/test_wrapper.py for a complete working example.
"""
......@@ -32,7 +47,6 @@ import hashlib
import numpy
from theano.gof.utils import MethodNotDefined
from theano.gof import Type
from theano.gof.cmodule import GCC_compiler as compiler
from theano.tensor.utils import hash_from_ndarray
# NB: Maybe we should check if an attribute name is a C/C++ keyword, and raise an error if so.
......@@ -178,32 +192,11 @@ class Wrapper(Type):
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None):
if strict:
try:
self.check_that_values_are_compatible(data, strict, allow_downcast)
except AttributeError as e:
raise TypeError('%s: strict mode: missing expected attribute in filtered data:\n%s' % (self, e))
except Exception as e:
raise TypeError('%s: strict mode: a data does not pass corresponding type filtering:\n%s' % (self, e))
return data
elif isinstance(data, dict):
wrap_instance = dict()
for i in range(self.length):
if self.fields[i] not in data:
raise TypeError('%s: filter expects a dictionary that has attribute "%s".' % (self, self.fields[i]))
try:
wrap_instance[self.fields[i]] = self.types[i].filter(data[self.fields[i]], strict, allow_downcast)
except Exception as e:
raise TypeError('%s: a data does not pass filtering for attribute "%s":\n%s' % (self, self.fields[i], e))
return Wrap(**wrap_instance)
else:
try:
wrapped_data = self.check_that_values_are_compatible(data, strict, allow_downcast)
except AttributeError as e:
raise TypeError('%s: missing expected attribute in filtered data:\n%s' % (self, e))
except Exception as e:
raise TypeError('%s: a data does not pass corresponding type filtering:\n%s' % (self, e))
return wrapped_data
if isinstance(data, dict):
if strict:
raise TypeError('%s: strict mode: data should be an object, not a dict.' % self)
data = Wrap(**data)
return self.check_that_values_are_compatible(data, strict, allow_downcast)
def values_eq(self, a, b):
for i in range(self.length):
......@@ -219,50 +212,50 @@ class Wrapper(Type):
return False
return True
def c_compile_args(self):
def c_compile_args(self, c_compiler):
c_compile_args_list = []
for _type in self.types:
try:
try:
c_compile_args_list.extend(_type.c_compile_args())
c_compile_args_list.extend(_type.c_compile_args(c_compiler))
except TypeError:
c_compile_args_list.extend(_type.c_compile_args(compiler))
c_compile_args_list.extend(_type.c_compile_args())
except MethodNotDefined:
pass
return c_compile_args_list
def c_no_compile_args(self):
def c_no_compile_args(self, c_compiler):
c_no_compile_args_list = []
for _type in self.types:
try:
try:
c_no_compile_args_list.extend(_type.c_no_compile_args())
c_no_compile_args_list.extend(_type.c_no_compile_args(c_compiler))
except TypeError:
c_no_compile_args_list.extend(_type.c_no_compile_args(compiler))
c_no_compile_args_list.extend(_type.c_no_compile_args())
except MethodNotDefined:
pass
return c_no_compile_args_list
def c_headers(self):
def c_headers(self, c_compiler):
c_headers_list = []
for _type in self.types:
try:
try:
c_headers_list.extend(_type.c_headers())
c_headers_list.extend(_type.c_headers(c_compiler))
except TypeError:
c_headers_list.extend(_type.c_headers(compiler))
c_headers_list.extend(_type.c_headers())
except MethodNotDefined:
pass
return c_headers_list
def c_libraries(self):
def c_libraries(self, c_compiler):
c_libraries_list = []
for _type in self.types:
try:
try:
c_libraries_list.extend(_type.c_libraries())
c_libraries_list.extend(_type.c_libraries(c_compiler))
except TypeError:
c_libraries_list.extend(_type.c_libraries(compiler))
c_libraries_list.extend(_type.c_libraries())
except MethodNotDefined:
pass
return c_libraries_list
......@@ -309,32 +302,20 @@ class Wrapper(Type):
for attribute_name, type_instance in zip(self.fields, self.types):
type_name = type_instance.__class__.__name__
try:
c_declare_list.append(type_instance.c_declare(attribute_name, sub))
except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement method Type.c_declare().' % type_name)
c_declare_list.append(type_instance.c_declare(attribute_name, sub))
try:
c_init_list.append(type_instance.c_init(attribute_name, sub))
except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement method Type.c_init().' % type_name)
c_init_list.append(type_instance.c_init(attribute_name, sub))
try:
c_cleanup_list.append(type_instance.c_cleanup(attribute_name, sub))
except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement method Type.c_cleanup().' % type_name)
c_cleanup_list.append(type_instance.c_cleanup(attribute_name, sub))
try:
c_extract_list.append("""
void extract_%(attribute_name)s(PyObject* py_%(attribute_name)s) {
%(extract_code)s
}
""" % {
'attribute_name': attribute_name,
'extract_code': type_instance.c_extract(attribute_name, sub)
})
except MethodNotDefined:
raise RuntimeError('Wrapper: class "%s" should implement the method Type.c_extract().' % type_name)
c_extract_list.append("""
void extract_%(attribute_name)s(PyObject* py_%(attribute_name)s) {
%(extract_code)s
}
""" % {
'attribute_name': attribute_name,
'extract_code': type_instance.c_extract(attribute_name, sub)
})
struct_fields = '\n'.join(c_declare_list)
struct_init = '\n'.join(c_init_list)
......@@ -399,7 +380,7 @@ class Wrapper(Type):
""" % locals()
def c_code_cache_version(self):
return (1, 1)
return (1, 2)
def c_declare(self, name, sub, check_input=True):
struct_name = self.name
......@@ -423,23 +404,21 @@ class Wrapper(Type):
check = 1 if check_input else 0
return """
const char* fields[] = {%(fields_list)s};
if (%(check)s) {
if (py_%(name)s == Py_None) {
PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None.");
%(fail)s
}
for (int i = 0; i < %(length)s; ++i) {
if (!PyObject_HasAttrString(py_%(name)s, fields[i])) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute %%s in object.", fields[i]);
%(fail)s
}
}
if (py_%(name)s == Py_None) {
PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None.");
%(fail)s
}
for (int i = 0; i < %(length)s; ++i) {
PyObject* o = PyObject_GetAttrString(py_%(name)s, fields[i]);
if (o == NULL) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute %%s in object.", fields[i]);
%(fail)s
}
%(name)s.extract(o, i);
if (%(name)s.errorOccurred()) {
PyErr_Format(PyExc_ValueError, "Wrapper: error when extracting value for attribute \\"%%s\\".", fields[i]);
/* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */
fprintf(stderr, "\\nWrapper: error when extracting value for attribute \\"%%s\\".\\n", fields[i]);
%(fail)s
}
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论