提交 fd552f23 authored 作者: notoraptor's avatar notoraptor

Add a COp test in test_wrapper.

Rewrite `Wrap` so that it depends on a Wrapper. Simplify code.
上级 5260c149
...@@ -799,22 +799,15 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -799,22 +799,15 @@ class Op(utils.object2, PureOp, CLinkerOp):
# We add a default get_params() implementation which will try to detect params from the op # We add a default get_params() implementation which will try to detect params from the op
# if params_type is set to a Wrapper. If not, we raise a MethodNotDefined exception. # if params_type is set to a Wrapper. If not, we raise a MethodNotDefined exception.
def get_params(self, node): def get_params(self, node):
if hasattr(self, 'params_type'): if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.wrapper.Wrapper):
# If params_type is a Wrapper, we try to extract params from the op. wrapper = self.params_type
if isinstance(self.params_type, theano.gof.wrapper.Wrapper): if hasattr(self, '__props__') and all(field in self.__props__ for field in wrapper.fields):
wrapper = self.params_type wrap_dict = dict()
op_has_wrap_attributes = True for i in range(wrapper.length):
for field in wrapper.fields: field = wrapper.fields[i]
if not hasattr(self, field): _type = wrapper.types[i]
op_has_wrap_attributes = False wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
break return theano.gof.wrapper.Wrap(wrapper, **wrap_dict)
if op_has_wrap_attributes:
wrap_dict = dict()
for i in range(wrapper.length):
field = wrapper.fields[i]
_type = wrapper.types[i]
wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
return theano.gof.wrapper.Wrap(**wrap_dict)
raise theano.gof.utils.MethodNotDefined('get_params') raise theano.gof.utils.MethodNotDefined('get_params')
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
......
#section support_code_apply
int APPLY_SPECIFIC(quadratic_function)(PyArrayObject* tensor, DTYPE_INPUT_0 a, DTYPE_INPUT_0 b, DTYPE_INPUT_0 c) {
NpyIter* iterator = NpyIter_New(tensor,
NPY_ITER_READWRITE | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK,
NPY_KEEPORDER, NPY_NO_CASTING, NULL);
if(iterator == NULL) {
PyErr_SetString(PyExc_RuntimeError, "Unable to iterate over a tensor for an elemwise operation.");
return -1;
}
NpyIter_IterNextFunc* get_next = NpyIter_GetIterNext(iterator, NULL);
char** data_ptr = NpyIter_GetDataPtrArray(iterator);
npy_intp* stride_ptr = NpyIter_GetInnerStrideArray(iterator);
npy_intp* innersize_ptr = NpyIter_GetInnerLoopSizePtr(iterator);
do {
char* data = *data_ptr;
npy_intp stride = *stride_ptr;
npy_intp count = *innersize_ptr;
while(count) {
DTYPE_INPUT_0 x = *((DTYPE_INPUT_0*)data);
*((DTYPE_INPUT_0*)data) = a*x*x + b*x + c;
data += stride;
--count;
}
} while(get_next(iterator));
NpyIter_Deallocate(iterator);
return 0;
}
int APPLY_SPECIFIC(compute_quadratic)(PyArrayObject* X, PyArrayObject** Y, QUADRATIC_WRAPPER* coeff) {
DTYPE_INPUT_0 a = (DTYPE_INPUT_0) (*(COEFF_TYPE*) PyArray_GETPTR1(coeff->a, 0)); // 0-D TensorType.
DTYPE_INPUT_0 b = coeff->b; // Scalar.
DTYPE_INPUT_0 c = (DTYPE_INPUT_0)PyFloat_AsDouble(coeff->c); // Generic.
Py_XDECREF(*Y);
*Y = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM(X), PyArray_DIMS(X), TYPENUM_INPUT_0, PyArray_IS_F_CONTIGUOUS(X));
if (PyArray_CopyInto(*Y, X) != 0) {
PyErr_SetString(PyExc_RuntimeError, "Unable to copy input into output.");
return 1;
};
if (APPLY_SPECIFIC(quadratic_function)(*Y, a, b, c) != 0) {
PyErr_SetString(PyExc_RuntimeError, "Unable to compute quadratic function.");
return 1;
}
return 0;
}
...@@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division ...@@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division
import theano import theano
import numpy import numpy
from unittest import TestCase from unittest import TestCase
from theano.gof import Op, Apply from theano.gof import Op, COp, Apply
from theano import Generic from theano import Generic
from theano.scalar import Scalar from theano.scalar import Scalar
from theano.tensor import TensorType from theano.tensor import TensorType
...@@ -18,7 +18,7 @@ generic_type = Generic() ...@@ -18,7 +18,7 @@ generic_type = Generic()
# A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as op params. # A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as op params.
class QuadraticFunction(Op): class QuadraticOpFunc(Op):
__props__ = ('a', 'b', 'c') __props__ = ('a', 'b', 'c')
params_type = Wrapper(a=tensor_type_0d, params_type = Wrapper(a=tensor_type_0d,
b=scalar_type, b=scalar_type,
...@@ -39,7 +39,7 @@ class QuadraticFunction(Op): ...@@ -39,7 +39,7 @@ class QuadraticFunction(Op):
y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c
def c_code_cache_version(self): def c_code_cache_version(self):
return (1, 2) return (1, 3)
def c_support_code_apply(self, node, name): def c_support_code_apply(self, node, name):
float_type = node.inputs[0].type.dtype_specs()[1] float_type = node.inputs[0].type.dtype_specs()[1]
...@@ -82,9 +82,9 @@ class QuadraticFunction(Op): ...@@ -82,9 +82,9 @@ class QuadraticFunction(Op):
float_typenum = numpy.dtype(node.inputs[0].type.dtype).num float_typenum = numpy.dtype(node.inputs[0].type.dtype).num
coeff_type = 'npy_' + numpy.dtype(dtype).name coeff_type = 'npy_' + numpy.dtype(dtype).name
return """ return """
%(float_type)s a = (%(float_type)s) (*(%(coeff_type)s*) PyArray_GETPTR1(%(coeff)s.a, 0)); // 0-D TensorType. %(float_type)s a = (%(float_type)s) (*(%(coeff_type)s*) PyArray_GETPTR1(%(coeff)s->a, 0)); // 0-D TensorType.
%(float_type)s b = %(coeff)s.b; // Scalar. %(float_type)s b = %(coeff)s->b; // Scalar.
%(float_type)s c = (%(float_type)s)PyFloat_AsDouble(%(coeff)s.c); // Generic. %(float_type)s c = (%(float_type)s)PyFloat_AsDouble(%(coeff)s->c); // Generic.
Py_XDECREF(%(Y)s); Py_XDECREF(%(Y)s);
%(Y)s = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM(%(X)s), PyArray_DIMS(%(X)s), %(float_typenum)s, PyArray_IS_F_CONTIGUOUS(%(X)s)); %(Y)s = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM(%(X)s), PyArray_DIMS(%(X)s), %(float_typenum)s, PyArray_IS_F_CONTIGUOUS(%(X)s));
if (PyArray_CopyInto(%(Y)s, %(X)s) != 0) { if (PyArray_CopyInto(%(Y)s, %(X)s) != 0) {
...@@ -98,29 +98,54 @@ class QuadraticFunction(Op): ...@@ -98,29 +98,54 @@ class QuadraticFunction(Op):
""" % locals() """ % locals()
# Same op as above, but implemented as a COp (with C code in an external file).
class QuadraticCOpFunc(COp):
__props__ = ('a', 'b', 'c')
params_type = Wrapper(a=tensor_type_0d,
b=scalar_type,
c=generic_type)
def get_op_params(self):
return [('QUADRATIC_WRAPPER', self.params_type.name),
('COEFF_TYPE', 'npy_' + numpy.dtype(dtype).name)]
def __init__(self, a, b, c):
super(QuadraticCOpFunc, self).__init__('test_quadratic_function.c',
'APPLY_SPECIFIC(compute_quadratic)')
self.a = a
self.b = b
self.c = c
def make_node(self, x):
x = tensor.as_tensor_variable(x)
return Apply(self, [x], [x.type()])
class TestWrapper(TestCase): class TestWrapper(TestCase):
def test_wrap_hash_and_eq(self): def test_hash_and_eq_wrap(self):
w1 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) wp1 = Wrapper(a=Generic(), array=TensorType('int32', (False,)), floatting=Scalar('float32'),
w2 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) npy_scalar=TensorType('float64', tuple()))
wp2 = Wrapper(a=Generic(), array=TensorType('int32', (False,)), floatting=Scalar('float32'),
npy_scalar=TensorType('float64', tuple()))
w1 = Wrap(wp1, a=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
w2 = Wrap(wp2, a=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 == w2 assert w1 == w2
assert not (w1 != w2) assert not (w1 != w2)
assert hash(w1) == hash(w2) assert hash(w1) == hash(w2)
assert all(hasattr(w1, key) for key in ('a', 'b', 'array', 'floatting', 'npy_scalar')) # Changing attributes names only (a -> other_name).
# Changing attributes names only. wp2_other = Wrapper(other_name=Generic(), array=TensorType('int32', (False,)), floatting=Scalar('float32'),
w2 = Wrap(other_name=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) npy_scalar=TensorType('float64', tuple()))
assert w1 != w2 w2 = Wrap(wp2_other, other_name=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
# Changing attributes types only.
w2 = Wrap(a=1, b='test string', array=[1, 2, 4, 5, 7], floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2 assert w1 != w2
# Changing attributes values only. # Changing attributes values only (now a=2).
w2 = Wrap(a=1, b='string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) w2 = Wrap(wp2, a=2, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2 assert w1 != w2
# Changing NumPy array values. # Changing NumPy array values (5 -> -5).
w2 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, -5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) w2 = Wrap(wp2, a=1, array=numpy.asarray([1, 2, 4, -5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2 assert w1 != w2
def test_wrapper_hash_and_eq(self): def test_hash_and_eq_wrapper(self):
w1 = Wrapper(a1=TensorType('int64', (False, False)), w1 = Wrapper(a1=TensorType('int64', (False, False)),
a2=TensorType('int64', (False, True, False, False, True)), a2=TensorType('int64', (False, True, False, False, True)),
a3=Generic()) a3=Generic())
...@@ -133,7 +158,7 @@ class TestWrapper(TestCase): ...@@ -133,7 +158,7 @@ class TestWrapper(TestCase):
assert w1.name == w2.name assert w1.name == w2.name
# Changing attributes names only. # Changing attributes names only.
w2 = Wrapper(a1=TensorType('int64', (False, False)), w2 = Wrapper(a1=TensorType('int64', (False, False)),
other_name=TensorType('int64', (False, True, False, False, True)), other_name=TensorType('int64', (False, True, False, False, True)), # a2 -> other_name
a3=Generic()) a3=Generic())
assert w1 != w2 assert w1 != w2
# Changing attributes types only. # Changing attributes types only.
...@@ -157,7 +182,8 @@ class TestWrapper(TestCase): ...@@ -157,7 +182,8 @@ class TestWrapper(TestCase):
a3=Generic()) a3=Generic())
# With a value that does not match the wrapper. # With a value that does not match the wrapper.
o = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'), o = Wrap(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'),
a2=random_tensor.astype('float32'), a2=random_tensor.astype('float32'),
a3=2000) a3=2000)
# should fail (o.a1 is not int32, o.a2 is not float64) # should fail (o.a1 is not int32, o.a2 is not float64)
...@@ -168,7 +194,8 @@ class TestWrapper(TestCase): ...@@ -168,7 +194,8 @@ class TestWrapper(TestCase):
w.filter(o, strict=False, allow_downcast=True) w.filter(o, strict=False, allow_downcast=True)
# With a value that matches the wrapper. # With a value that matches the wrapper.
o1 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'), o1 = Wrap(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=random_tensor.astype('float64'), a2=random_tensor.astype('float64'),
a3=2000) a3=2000)
# All should pass. # All should pass.
...@@ -176,15 +203,17 @@ class TestWrapper(TestCase): ...@@ -176,15 +203,17 @@ class TestWrapper(TestCase):
w.filter(o1, strict=False, allow_downcast=False) w.filter(o1, strict=False, allow_downcast=False)
w.filter(o1, strict=False, allow_downcast=True) w.filter(o1, strict=False, allow_downcast=True)
# Check value_eq and value_eq_approx. # Check values_eq and values_eq_approx.
o2 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'), o2 = Wrap(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=random_tensor.astype('float64'), a2=random_tensor.astype('float64'),
a3=2000) a3=2000)
assert w.values_eq(o1, o2) assert w.values_eq(o1, o2)
assert w.values_eq_approx(o1, o2) assert w.values_eq_approx(o1, o2)
# Check value_eq_approx. # Check value_eq_approx.
o3 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('float32'), o3 = Wrap(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('float32'),
a2=random_tensor.astype('float64'), a2=random_tensor.astype('float64'),
a3=2000.0) a3=2000.0)
assert w.values_eq_approx(o1, o3) assert w.values_eq_approx(o1, o3)
...@@ -192,14 +221,18 @@ class TestWrapper(TestCase): ...@@ -192,14 +221,18 @@ class TestWrapper(TestCase):
def test_op_params(self): def test_op_params(self):
a, b, c = 2, 3, -7 a, b, c = 2, 3, -7
x = tensor.matrix() x = tensor.matrix()
y = QuadraticFunction(a, b, c)(x) y1 = QuadraticOpFunc(a, b, c)(x)
f = theano.function([x], y) y2 = QuadraticCOpFunc(a, b, c)(x)
f1 = theano.function([x], y1)
f2 = theano.function([x], y2)
shape = (100, 100) shape = (100, 100)
# The for-loop is here just to force profiling print something interesting. # The for-loop is here just to force profiling print something interesting.
# When running this test without this loop, profiling does not print neither list of classes nor list of ops # When running this test without this loop, profiling does not print neither list of classes nor list of ops
# (maybe because the function is extremely fast ?). # (maybe because the function is extremely fast ?).
for i in range(50): for i in range(50):
vx = numpy.random.normal(size=shape[0] * shape[1]).astype(dtype).reshape(*shape) vx = numpy.random.normal(size=shape[0] * shape[1]).astype(dtype).reshape(*shape)
vy = f(vx) vy1 = f1(vx)
vy2 = f2(vx)
ref = a * (vx**2) + b * vx + c ref = a * (vx**2) + b * vx + c
utt.assert_allclose(ref, vy) utt.assert_allclose(vy1, vy2)
utt.assert_allclose(ref, vy1)
""" """
Module for wrapping many Theano variables into one struct for op params. Module for wrapping many Theano variables into one C struct for op params.
This module contains two classes: This module contains two classes:
- Wrapper: class to define the op params type.
- Wrap: internal convenient class to create an object that is compatible with a Wrapper-defined op params.
Example of usage: - :class:`Wrapper`: main class to define the op params type.
- :class:`Wrap`: internal convenient class to create an object that is compatible with Wrapper-defined op params.
Importation: Example of usage
----------------
from theano.gof.wrapper import Wrapper Importation:
In an op you create: .. code-block:: python
params_type = Wrapper(attr1=TensorType('int32', (False, False)), attr2=TensorType('float64', (True, False))) from theano.gof.wrapper import Wrapper
If your op contains props `attr1` AND `attr2`, the op.get_params() method will In an op you create:
automatically try to look for it and generate an appropriate wrapped struct.
The props must be able to pass the filtering (not strict, downcasting allowed)
of corresponding types defined into Wrapper.
__props__ = ('attr1', 'attr2') .. code-block:: python
def __init__(value_attr1, value_attr2):
self.attr1 = value_attr1
self.attr2 = value_attr2
In perform() implementation (with params named `param`): from theano.tensor import TensorType, dmatrix
params_type = Wrapper(attr1=TensorType('int32', (False, False)), attr2=dmatrix)
var1 = param.attr1 If your op contains props ``attr1`` *and* ``attr2``, the default ``op.get_params()`` implementation
var2 = param.attr2 will automatically try to look for it and generate an appropriate wrapped struct.
Props must be compatible with the corresponding types defined into the Wrapper
(we will try to convert and downcast if needed).
In c_code() implementation (with `param = sub['params']`): .. code-block:: python
PyArrayObject* attr1 = param.attr1; __props__ = ('attr1', 'attr2')
PyArrayObject* attr2 = param.attr2; def __init__(value_attr1, value_attr2):
/* You won't need to free them or whatever else. */ self.attr1 = value_attr1
self.attr2 = value_attr2
In ``perform()`` implementation (with params named ``param``):
See `theano/gof/tests/test_wrapper.py` for a complete working example. .. code-block:: python
var1 = param.attr1
var2 = param.attr2
In ``c_code()`` implementation (with ``param = sub['params']``):
.. code-block:: c
PyArrayObject* attr1 = param->attr1;
PyArrayObject* attr2 = param->attr2;
/* You won't need to free them or whatever else. */
See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_wrapper.py``
for complete working examples.
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import re import re
import hashlib import hashlib
import numpy from theano.gof.utils import MethodNotDefined, c_cpp_keywords
from theano.gof.utils import MethodNotDefined
from theano.gof import Type from theano.gof import Type
from theano.tensor.utils import hash_from_ndarray
# NB: Maybe we should check if an attribute name is a C/C++ keyword, and raise an error if so.
# These are some lists of C/C++ keywords:
# http://fr.cppreference.com/w/cpp/keyword
# http://fr.cppreference.com/w/c/keyword
class Wrap(dict): class Wrap(dict):
...@@ -60,19 +67,31 @@ class Wrap(dict): ...@@ -60,19 +67,31 @@ class Wrap(dict):
Internal convenient class to wrap many Python objects into one Internal convenient class to wrap many Python objects into one
(this class is not safe as the hash method does not check if values are effectively hashable). (this class is not safe as the hash method does not check if values are effectively hashable).
Example: **Example:**
>>> w = Wrap(attr1=1, attr2=2.0, attri='3')
>>> print(w.attr1, w.attr2, w.attri) .. code-block:: python
>>> d = dict(a=1, b=2, c='test')
>>> w2 = Wrap(**d) from theano.gof.wrapper import *
>>> print(w2.a, w2.b, w2.c) from theano.scalar import Scalar
# You must create a Wrapper first:
wp = Wrapper(attr1=Scalar('int32'), key2=Scalar('float32'), field3=Scalar('int64'))
# Then you can create a Wrap with the wrapper defined above and values for attributes.
w = Wrap(wp, attr1=1, key2=2.0, field3=3)
print(w.attr1, w.key2, w.field3)
d = dict(attr1=1, key2=2, field3=-1)
w2 = Wrap(wp, **d)
print(w2.attr1, w2.key2, w2.field3)
""" """
def __init__(self, **kwargs): def __init__(self, wrapper, **kwargs):
if not isinstance(wrapper, Wrapper):
raise TypeError('Wrap: 1st constructor argument should be a Wrapper.')
for field in wrapper.fields:
if field not in kwargs:
raise TypeError('Wrap: Wrapper attribute "%s" not in Wrap args.' % field)
super(Wrap, self).__init__(**kwargs) super(Wrap, self).__init__(**kwargs)
if len(kwargs) == 0: self.__dict__.update(wrapper=wrapper)
raise TypeError('Wrap: cannot wrap empty data.')
def __repr__(self): def __repr__(self):
return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())]) return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())])
...@@ -82,33 +101,28 @@ class Wrap(dict): ...@@ -82,33 +101,28 @@ class Wrap(dict):
raise AttributeError('Wrap: attribute "%s" does not exist.' % key) raise AttributeError('Wrap: attribute "%s" does not exist.' % key)
return self[key] return self[key]
def __setattr__(self, key, value):
raise NotImplementedError('Wrap is immutable')
def __setitem__(self, key, value):
raise NotImplementedError('Wrap is immutable')
def __delitem__(self, key):
raise NotImplementedError('Wrap is immutable')
def __hash__(self): def __hash__(self):
keys = sorted(self.keys()) return hash((type(self), self.wrapper) + tuple(
types = [] # NB: Wrapped data should have been already filtered.
attributes = [] self.wrapper.types[i].make_constant(self[self.wrapper.fields[i]]).signature()
for k in keys: for i in range(self.wrapper.length)
types += (type(self[k]),) ))
if isinstance(self[k], numpy.ndarray):
# Note: hash_from_ndarray returns a string, so the hash is not yet complete
# (__hash__ must return an integer).
attributes += (hash_from_ndarray(self[k]),)
else:
# No checking, data should be hashable.
attributes += (self[k],)
return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes))
def __eq__(self, other): def __eq__(self, other):
if type(self) != type(other) or len(self) != len(other): return (type(self) == type(other) and self.wrapper == other.wrapper and all(
return False # NB: Wrapped data should have been already filtered.
for k in self: self.wrapper.types[i].values_eq(self[self.wrapper.fields[i]], other[self.wrapper.fields[i]])
if k not in other or not (isinstance(self[k], type(other[k])) and isinstance(other[k], type(self[k]))): for i in range(self.wrapper.length)
return False ))
if isinstance(self[k], numpy.ndarray):
if not numpy.allclose(self[k], other[k]):
return False
elif self[k] != other[k]:
return False
return True
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
...@@ -121,18 +135,23 @@ class Wrapper(Type): ...@@ -121,18 +135,23 @@ class Wrapper(Type):
Wrapper constructor takes key-value args. Wrapper constructor takes key-value args.
Key will be the name of the attribute in the struct. Key will be the name of the attribute in the struct.
Value is the Theano type of this attribute, ie. an instance of (a subclass of) Type Value is the Theano type of this attribute, ie. an instance of (a subclass of) :class:`Type`
(eg. TensorType('int64', (False,))). (eg. ``TensorType('int64', (False,))``).
In a Python code any attribute named `key` will be available via: In a Python code any attribute named ``key`` will be available via::
structObject.key
In a C code, attributes created to represent an instance of the type associated to `key` will be available via:
structObject.key structObject.key
structObject.dtype_key # e.g. from TensorType C code.
structObject.other_attribute_named_from_key
etc.
In a C code, attributes created to represent an instance of the type associated to ``key`` will be available via:
.. code-block:: c
structObject->key;
structObject->dtype_key; // e.g. from TensorType C code.
structObject->other_attribute_named_from_key;
/* etc. */
**NB**: This Type is not a complete type and should never be used for regular graph operations.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -142,10 +161,14 @@ class Wrapper(Type): ...@@ -142,10 +161,14 @@ class Wrapper(Type):
for attribute_name in kwargs: for attribute_name in kwargs:
if re.match('^[A-Za-z_][A-Za-z0-9_]*$', attribute_name) is None: if re.match('^[A-Za-z_][A-Za-z0-9_]*$', attribute_name) is None:
raise SyntaxError('Wrapper: attribute "%s" should be a valid identifier.' % attribute_name) raise SyntaxError('Wrapper: attribute "%s" should be a valid identifier.' % attribute_name)
if attribute_name in c_cpp_keywords:
print(len(c_cpp_keywords))
raise SyntaxError('Wrapper: "%s" is a potential C/C++ keyword and should not be used as attribute name.'
% attribute_name)
type_instance = kwargs[attribute_name] type_instance = kwargs[attribute_name]
type_name = type_instance.__class__.__name__ type_name = type_instance.__class__.__name__
if not isinstance(type_instance, Type): if not isinstance(type_instance, Type):
raise TypeError('Wrapper: attribute "%s" should inherit from theano Type, got "%s".' raise TypeError('Wrapper: attribute "%s" should inherit from Theano Type, got "%s".'
% (attribute_name, type_name)) % (attribute_name, type_name))
self.length = len(kwargs) self.length = len(kwargs)
...@@ -164,49 +187,44 @@ class Wrapper(Type): ...@@ -164,49 +187,44 @@ class Wrapper(Type):
return hash((type(self),) + self.fields + self.types) return hash((type(self),) + self.fields + self.types)
def generate_struct_name(self): def generate_struct_name(self):
"""" # This method tries to generate an unique name for the current instance.
This method tries to generate an unique name for the current instance. # This name is intended to be used as struct name in C code and as constant
This name is intended to be used as struct name in C code and as constant # definition to check if a similar Wrapper has already been created
definition to check if a similar Wrapper has already been created # (see c_support_code() below).
(see c_support_code() below).
"""
fields_string = ','.join(self.fields).encode('utf-8') fields_string = ','.join(self.fields).encode('utf-8')
types_string = ','.join(str(t) for t in self.types).encode('utf-8') types_string = ','.join(str(t) for t in self.types).encode('utf-8')
fields_hex = hashlib.md5(fields_string).hexdigest() fields_hex = hashlib.md5(fields_string).hexdigest()
types_hex = hashlib.md5(types_string).hexdigest() types_hex = hashlib.md5(types_string).hexdigest()
return '_wrapper_%s_%s' % (fields_hex, types_hex) return '_wrapper_%s_%s' % (fields_hex, types_hex)
def check_that_values_are_compatible(self, data, strict, allow_downcast): def wrap_data(self, data, strict, allow_downcast):
# Try to wrap data. Raise an exception if data does not respect the Wrapper's contract.
wrap_instance = dict() wrap_instance = dict()
for i in range(self.length): for i in range(self.length):
wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast) wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast)
return data if strict else Wrap(**wrap_instance) return data if strict else Wrap(self, **wrap_instance)
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes. # Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if strict and not isinstance(data, Wrap): if strict and not isinstance(data, Wrap):
raise TypeError('%s: strict mode: data should be an instance of Wrap.' % self) raise TypeError('%s: strict mode: data should be an instance of Wrap.' % self)
return self.check_that_values_are_compatible(data, strict, allow_downcast) return self.wrap_data(data, strict, allow_downcast)
def values_eq(self, a, b): def values_eq(self, a, b):
# We check that a and b have expected attributes and strict values. # We check that a and b have expected attributes and strict values.
a = self.filter(a, strict=True) a = self.filter(a, strict=True)
b = self.filter(b, strict=True) b = self.filter(b, strict=True)
# Then we compare. # Then we compare.
for i in range(self.length): return all(self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i]))
if not self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])): for i in range(self.length))
return False
return True
def values_eq_approx(self, a, b): def values_eq_approx(self, a, b):
# We check, wrap and round a and b if necessary. # We check, wrap and round a and b if necessary.
a = self.filter(a, strict=False, allow_downcast=True) a = self.filter(a, strict=False, allow_downcast=True)
b = self.filter(b, strict=False, allow_downcast=True) b = self.filter(b, strict=False, allow_downcast=True)
# Then we compare. # Then we compare.
for i in range(self.length): return all(self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i]))
if not self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i])): for i in range(self.length))
return False
return True
def c_compile_args(self, c_compiler): def c_compile_args(self, c_compiler):
c_compile_args_list = [] c_compile_args_list = []
...@@ -375,28 +393,40 @@ class Wrapper(Type): ...@@ -375,28 +393,40 @@ class Wrapper(Type):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1, 4) return (1, 5)
# As this struct has constructor and destructor, it could be instanciated on stack,
# but current implementations of C ops will then pass the instance by value at functions,
# so it's better to work directly with pointers.
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
struct_name = self.name struct_name = self.name
return """ return """
%(struct_name)s %(name)s; %(struct_name)s* %(name)s;
""" % locals() """ % locals()
# c_init() and c_cleanup() are useless if we create the struct
# on stack, as struct class has constructor and destructor.
def c_init(self, name, sub): def c_init(self, name, sub):
return "" # NB: It seems c_init() is not called for an op param.
# So the real initialization is done at top of c_extract.
return """
%(nams)s = NULL;
""" % locals()
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return "" return """
delete %(name)s;
%(name)s = NULL;
""" % locals()
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
struct_name = self.name
fail = sub['fail'] fail = sub['fail']
length = self.length length = self.length
fields_list = '"%s"' % '", "'.join(self.fields) fields_list = '"%s"' % '", "'.join(self.fields)
return """ return """
/* Seems c_init() is not called for a op param. So I call `new` here. */
%(name)s = new %(struct_name)s;
const char* fields[] = {%(fields_list)s}; const char* fields[] = {%(fields_list)s};
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None."); PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None.");
...@@ -408,8 +438,8 @@ class Wrapper(Type): ...@@ -408,8 +438,8 @@ class Wrapper(Type):
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute \\"%%s\\" in object.", fields[i]); PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute \\"%%s\\" in object.", fields[i]);
%(fail)s %(fail)s
} }
%(name)s.extract(o, i); %(name)s->extract(o, i);
if (%(name)s.errorOccurred()) { if (%(name)s->errorOccurred()) {
/* The extract code from attribute type should have already raised a Python exception, /* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */ * so we just print the attribute name in stderr. */
fprintf(stderr, "\\nWrapper: error when extracting value for attribute \\"%%s\\".\\n", fields[i]); fprintf(stderr, "\\nWrapper: error when extracting value for attribute \\"%%s\\".\\n", fields[i]);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论