提交 adaa5f57 authored 作者: notoraptor's avatar notoraptor

Final re-reading.

Modify `Wrapper.value_eq` and `Wrapper.value_eq_approx`. Fix typos, rename variables, clarify doc strings and comments.
上级 feedbb19
...@@ -17,7 +17,7 @@ scalar_type = Scalar(dtype) ...@@ -17,7 +17,7 @@ scalar_type = Scalar(dtype)
generic_type = Generic() generic_type = Generic()
# A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as parameters of that op. # A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as op params.
class QuadraticFunction(Op): class QuadraticFunction(Op):
__props__ = ('a', 'b', 'c') __props__ = ('a', 'b', 'c')
params_type = Wrapper(a=tensor_type_0d, params_type = Wrapper(a=tensor_type_0d,
......
""" """
Module for wrapping many Theano variables into one struct for param ops. Module for wrapping many Theano variables into one struct for op params.
This module contains two classes: This module contains two classes:
- Wrapper: class to define the param op type. - Wrapper: class to define the op params type.
- Wrap: convenient class to create an object that is compatible with the param op type. - Wrap: internal convenient class to create an object that is compatible with a Wrapper-defined op params.
Example of usage: Example of usage:
...@@ -34,10 +34,10 @@ Example of usage: ...@@ -34,10 +34,10 @@ Example of usage:
PyArrayObject* attr1 = param.attr1; PyArrayObject* attr1 = param.attr1;
PyArrayObject* attr2 = param.attr2; PyArrayObject* attr2 = param.attr2;
/* Just use attr1 and attr2, you won't need to free them or whatever else. */ /* You won't need to free them or whatever else. */
See theano/gof/tests/test_wrapper.py for a complete working example. See `theano/gof/tests/test_wrapper.py` for a complete working example.
""" """
...@@ -61,7 +61,7 @@ class Wrap(object): ...@@ -61,7 +61,7 @@ class Wrap(object):
(this class is not safe as the hash method does not check if values are effectively hashable). (this class is not safe as the hash method does not check if values are effectively hashable).
Example: Example:
>>> w = Wrap(attr1=var1, attr2=var2, attri=vari) >>> w = Wrap(attr1=1, attr2=2.0, attri='3')
>>> print(w.attr1, w.attr2, w.attri) >>> print(w.attr1, w.attr2, w.attri)
>>> d = dict(a=1, b=2, c='test') >>> d = dict(a=1, b=2, c='test')
>>> w2 = Wrap(**d) >>> w2 = Wrap(**d)
...@@ -119,7 +119,6 @@ class Wrap(object): ...@@ -119,7 +119,6 @@ class Wrap(object):
elif self.data[k] != other.data[k]: elif self.data[k] != other.data[k]:
return False return False
return True return True
# return type(self) == type(other) and self.data == other.data
class Wrapper(Type): class Wrapper(Type):
...@@ -127,9 +126,9 @@ class Wrapper(Type): ...@@ -127,9 +126,9 @@ class Wrapper(Type):
This class can create a struct of Theano types (like TensorType, GpuArrayType, etc.) This class can create a struct of Theano types (like TensorType, GpuArrayType, etc.)
to be used as a convenience op parameter wrapping many data. to be used as a convenience op parameter wrapping many data.
Wrapper constructor takes many key-value args. Wrapper constructor takes key-value args.
Key will be the name of the attribute in the struct. Key will be the name of the attribute in the struct.
Value is the Theano type of this attribute, that is an instance of (a subclass of) Type Value is the Theano type of this attribute, ie. an instance of (a subclass of) Type
(eg. TensorType('int64', (False,))). (eg. TensorType('int64', (False,))).
In a Python code any attribute named `key` will be available via: In a Python code any attribute named `key` will be available via:
...@@ -182,13 +181,13 @@ class Wrapper(Type): ...@@ -182,13 +181,13 @@ class Wrapper(Type):
types_string = ','.join(str(t) for t in self.types).encode('utf-8') types_string = ','.join(str(t) for t in self.types).encode('utf-8')
fields_hex = hashlib.md5(fields_string).hexdigest() fields_hex = hashlib.md5(fields_string).hexdigest()
types_hex = hashlib.md5(types_string).hexdigest() types_hex = hashlib.md5(types_string).hexdigest()
return '_wrapper_struct_%s_%s' % (fields_hex, types_hex) return '_wrapper_%s_%s' % (fields_hex, types_hex)
def check_that_values_are_compatible(self, data, strict, allow_downcast): def check_that_values_are_compatible(self, data, strict, allow_downcast):
wrap_instance = dict() wrap_instance = dict()
for i in range(self.length): for i in range(self.length):
wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast) wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast)
return Wrap(**wrap_instance) return data if strict else Wrap(**wrap_instance)
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes. # Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
...@@ -199,14 +198,20 @@ class Wrapper(Type): ...@@ -199,14 +198,20 @@ class Wrapper(Type):
return self.check_that_values_are_compatible(data, strict, allow_downcast) return self.check_that_values_are_compatible(data, strict, allow_downcast)
def values_eq(self, a, b): def values_eq(self, a, b):
# We check that a and b have expected attributes and strict values.
a = self.filter(a, strict=True)
b = self.filter(b, strict=True)
# Then we compare.
for i in range(self.length): for i in range(self.length):
if not self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])): if not self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
return False return False
return True return True
def values_eq_approx(self, a, b): def values_eq_approx(self, a, b):
# We check, wrap and round a and b if necessary.
a = self.filter(a, strict=False, allow_downcast=True) a = self.filter(a, strict=False, allow_downcast=True)
b = self.filter(b, strict=False, allow_downcast=True) b = self.filter(b, strict=False, allow_downcast=True)
# Then we compare.
for i in range(self.length): for i in range(self.length):
if not self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i])): if not self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
return False return False
...@@ -291,16 +296,15 @@ class Wrapper(Type): ...@@ -291,16 +296,15 @@ class Wrapper(Type):
sub = {'fail': '{this->setErrorOccurred(); this->cleanup(); return;}'} sub = {'fail': '{this->setErrorOccurred(); this->cleanup(); return;}'}
struct_name = self.name struct_name = self.name
struct_name_defined = struct_name.upper() struct_name_defined = struct_name.upper()
struct_fields = '' struct_declare = ''
struct_init = '' struct_init = ''
struct_cleanup = '' struct_cleanup = ''
struct_extraction_methods = '' struct_extract = ''
c_declare_list = [] c_declare_list = []
c_init_list = [] c_init_list = []
c_cleanup_list = [] c_cleanup_list = []
c_extract_list = [] c_extract_list = []
for attribute_name, type_instance in zip(self.fields, self.types): for attribute_name, type_instance in zip(self.fields, self.types):
type_name = type_instance.__class__.__name__
c_declare_list.append(type_instance.c_declare(attribute_name, sub)) c_declare_list.append(type_instance.c_declare(attribute_name, sub))
...@@ -317,10 +321,10 @@ class Wrapper(Type): ...@@ -317,10 +321,10 @@ class Wrapper(Type):
'extract_code': type_instance.c_extract(attribute_name, sub) 'extract_code': type_instance.c_extract(attribute_name, sub)
}) })
struct_fields = '\n'.join(c_declare_list) struct_declare = '\n'.join(c_declare_list)
struct_init = '\n'.join(c_init_list) struct_init = '\n'.join(c_init_list)
struct_cleanup = '\n'.join(c_cleanup_list) struct_cleanup = '\n'.join(c_cleanup_list)
struct_extraction_methods = '\n\n'.join(c_extract_list) struct_extract = '\n\n'.join(c_extract_list)
struct_extract_method = """ struct_extract_method = """
void extract(PyObject* object, int field_pos) { void extract(PyObject* object, int field_pos) {
switch(field_pos) { switch(field_pos) {
...@@ -343,7 +347,7 @@ class Wrapper(Type): ...@@ -343,7 +347,7 @@ class Wrapper(Type):
struct %(struct_name)s { struct %(struct_name)s {
/* Attributes, */ /* Attributes, */
int %(struct_name)s_error; int %(struct_name)s_error;
%(struct_fields)s %(struct_declare)s
/* Constructor. */ /* Constructor. */
%(struct_name)s() { %(struct_name)s() {
...@@ -363,7 +367,7 @@ class Wrapper(Type): ...@@ -363,7 +367,7 @@ class Wrapper(Type):
} }
/* Extraction methods. */ /* Extraction methods. */
%(struct_extraction_methods)s %(struct_extract)s
/* Extract method. */ /* Extract method. */
%(struct_extract_method)s %(struct_extract_method)s
...@@ -380,7 +384,7 @@ class Wrapper(Type): ...@@ -380,7 +384,7 @@ class Wrapper(Type):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1, 2) return (1, 3)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
struct_name = self.name struct_name = self.name
...@@ -388,8 +392,8 @@ class Wrapper(Type): ...@@ -388,8 +392,8 @@ class Wrapper(Type):
%(struct_name)s %(name)s; %(struct_name)s %(name)s;
""" % locals() """ % locals()
# c_init() and c_cleanup() are useless if we create the struct on stack # c_init() and c_cleanup() are useless if we create the struct
# because the struct has constructor and destructor. # on stack, as struct class has constructor and destructor.
def c_init(self, name, sub): def c_init(self, name, sub):
return "" return ""
...@@ -401,7 +405,6 @@ class Wrapper(Type): ...@@ -401,7 +405,6 @@ class Wrapper(Type):
fail = sub['fail'] fail = sub['fail']
length = self.length length = self.length
fields_list = '"%s"' % '", "'.join(self.fields) fields_list = '"%s"' % '", "'.join(self.fields)
check = 1 if check_input else 0
return """ return """
const char* fields[] = {%(fields_list)s}; const char* fields[] = {%(fields_list)s};
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
...@@ -411,7 +414,7 @@ class Wrapper(Type): ...@@ -411,7 +414,7 @@ class Wrapper(Type):
for (int i = 0; i < %(length)s; ++i) { for (int i = 0; i < %(length)s; ++i) {
PyObject* o = PyObject_GetAttrString(py_%(name)s, fields[i]); PyObject* o = PyObject_GetAttrString(py_%(name)s, fields[i]);
if (o == NULL) { if (o == NULL) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute %%s in object.", fields[i]); PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute \\"%%s\\" in object.", fields[i]);
%(fail)s %(fail)s
} }
%(name)s.extract(o, i); %(name)s.extract(o, i);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论