提交 adaa5f57 authored 作者: notoraptor's avatar notoraptor

Final re-reading.

Modify `Wrapper.value_eq` and `Wrapper.value_eq_approx`. Fix typos, rename variables, clarify doc strings and comments.
上级 feedbb19
......@@ -17,7 +17,7 @@ scalar_type = Scalar(dtype)
generic_type = Generic()
# A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as parameters of that op.
# A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as op params.
class QuadraticFunction(Op):
__props__ = ('a', 'b', 'c')
params_type = Wrapper(a=tensor_type_0d,
......
"""
Module for wrapping many Theano variables into one struct for param ops.
Module for wrapping many Theano variables into one struct for op params.
This module contains two classes:
- Wrapper: class to define the param op type.
- Wrap: convenient class to create an object that is compatible with the param op type.
- Wrapper: class to define the op params type.
- Wrap: internal convenient class to create an object that is compatible with a Wrapper-defined op params.
Example of usage:
......@@ -34,10 +34,10 @@ Example of usage:
PyArrayObject* attr1 = param.attr1;
PyArrayObject* attr2 = param.attr2;
/* Just use attr1 and attr2, you won't need to free them or whatever else. */
/* You won't need to free them or whatever else. */
See theano/gof/tests/test_wrapper.py for a complete working example.
See `theano/gof/tests/test_wrapper.py` for a complete working example.
"""
......@@ -61,7 +61,7 @@ class Wrap(object):
(this class is not safe as the hash method does not check if values are effectively hashable).
Example:
>>> w = Wrap(attr1=var1, attr2=var2, attri=vari)
>>> w = Wrap(attr1=1, attr2=2.0, attri='3')
>>> print(w.attr1, w.attr2, w.attri)
>>> d = dict(a=1, b=2, c='test')
>>> w2 = Wrap(**d)
......@@ -119,7 +119,6 @@ class Wrap(object):
elif self.data[k] != other.data[k]:
return False
return True
# return type(self) == type(other) and self.data == other.data
class Wrapper(Type):
......@@ -127,9 +126,9 @@ class Wrapper(Type):
This class can create a struct of Theano types (like TensorType, GpuArrayType, etc.)
to be used as a convenience op parameter wrapping many data.
Wrapper constructor takes many key-value args.
Wrapper constructor takes key-value args.
Key will be the name of the attribute in the struct.
Value is the Theano type of this attribute, that is an instance of (a subclass of) Type
Value is the Theano type of this attribute, ie. an instance of (a subclass of) Type
(eg. TensorType('int64', (False,))).
In a Python code any attribute named `key` will be available via:
......@@ -182,13 +181,13 @@ class Wrapper(Type):
types_string = ','.join(str(t) for t in self.types).encode('utf-8')
fields_hex = hashlib.md5(fields_string).hexdigest()
types_hex = hashlib.md5(types_string).hexdigest()
return '_wrapper_struct_%s_%s' % (fields_hex, types_hex)
return '_wrapper_%s_%s' % (fields_hex, types_hex)
def check_that_values_are_compatible(self, data, strict, allow_downcast):
wrap_instance = dict()
for i in range(self.length):
wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast)
return Wrap(**wrap_instance)
return data if strict else Wrap(**wrap_instance)
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None):
......@@ -199,14 +198,20 @@ class Wrapper(Type):
return self.check_that_values_are_compatible(data, strict, allow_downcast)
def values_eq(self, a, b):
# We check that a and b have expected attributes and strict values.
a = self.filter(a, strict=True)
b = self.filter(b, strict=True)
# Then we compare.
for i in range(self.length):
if not self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
return False
return True
def values_eq_approx(self, a, b):
# We check, wrap and round a and b if necessary.
a = self.filter(a, strict=False, allow_downcast=True)
b = self.filter(b, strict=False, allow_downcast=True)
# Then we compare.
for i in range(self.length):
if not self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
return False
......@@ -291,16 +296,15 @@ class Wrapper(Type):
sub = {'fail': '{this->setErrorOccurred(); this->cleanup(); return;}'}
struct_name = self.name
struct_name_defined = struct_name.upper()
struct_fields = ''
struct_declare = ''
struct_init = ''
struct_cleanup = ''
struct_extraction_methods = ''
struct_extract = ''
c_declare_list = []
c_init_list = []
c_cleanup_list = []
c_extract_list = []
for attribute_name, type_instance in zip(self.fields, self.types):
type_name = type_instance.__class__.__name__
c_declare_list.append(type_instance.c_declare(attribute_name, sub))
......@@ -317,10 +321,10 @@ class Wrapper(Type):
'extract_code': type_instance.c_extract(attribute_name, sub)
})
struct_fields = '\n'.join(c_declare_list)
struct_declare = '\n'.join(c_declare_list)
struct_init = '\n'.join(c_init_list)
struct_cleanup = '\n'.join(c_cleanup_list)
struct_extraction_methods = '\n\n'.join(c_extract_list)
struct_extract = '\n\n'.join(c_extract_list)
struct_extract_method = """
void extract(PyObject* object, int field_pos) {
switch(field_pos) {
......@@ -343,7 +347,7 @@ class Wrapper(Type):
struct %(struct_name)s {
/* Attributes, */
int %(struct_name)s_error;
%(struct_fields)s
%(struct_declare)s
/* Constructor. */
%(struct_name)s() {
......@@ -363,7 +367,7 @@ class Wrapper(Type):
}
/* Extraction methods. */
%(struct_extraction_methods)s
%(struct_extract)s
/* Extract method. */
%(struct_extract_method)s
......@@ -380,7 +384,7 @@ class Wrapper(Type):
""" % locals()
def c_code_cache_version(self):
return (1, 2)
return (1, 3)
def c_declare(self, name, sub, check_input=True):
struct_name = self.name
......@@ -388,8 +392,8 @@ class Wrapper(Type):
%(struct_name)s %(name)s;
""" % locals()
# c_init() and c_cleanup() are useless if we create the struct on stack
# because the struct has constructor and destructor.
# c_init() and c_cleanup() are useless if we create the struct
# on stack, as struct class has constructor and destructor.
def c_init(self, name, sub):
return ""
......@@ -401,7 +405,6 @@ class Wrapper(Type):
fail = sub['fail']
length = self.length
fields_list = '"%s"' % '", "'.join(self.fields)
check = 1 if check_input else 0
return """
const char* fields[] = {%(fields_list)s};
if (py_%(name)s == Py_None) {
......@@ -411,7 +414,7 @@ class Wrapper(Type):
for (int i = 0; i < %(length)s; ++i) {
PyObject* o = PyObject_GetAttrString(py_%(name)s, fields[i]);
if (o == NULL) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute %%s in object.", fields[i]);
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute \\"%%s\\" in object.", fields[i]);
%(fail)s
}
%(name)s.extract(o, i);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论