提交 fd552f23 authored 作者: notoraptor's avatar notoraptor

Add a COp test in test_wrapper.

Rewrite `Wrap` so that it depends on a Wrapper. Simplify code.
上级 5260c149
......@@ -799,22 +799,15 @@ class Op(utils.object2, PureOp, CLinkerOp):
# We add a default get_params() implementation which will try to detect params from the op
# if params_type is set to a Wrapper. If not, we raise a MethodNotDefined exception.
def get_params(self, node):
if hasattr(self, 'params_type'):
# If params_type is a Wrapper, we try to extract params from the op.
if isinstance(self.params_type, theano.gof.wrapper.Wrapper):
wrapper = self.params_type
op_has_wrap_attributes = True
for field in wrapper.fields:
if not hasattr(self, field):
op_has_wrap_attributes = False
break
if op_has_wrap_attributes:
wrap_dict = dict()
for i in range(wrapper.length):
field = wrapper.fields[i]
_type = wrapper.types[i]
wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
return theano.gof.wrapper.Wrap(**wrap_dict)
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.wrapper.Wrapper):
wrapper = self.params_type
if hasattr(self, '__props__') and all(field in self.__props__ for field in wrapper.fields):
wrap_dict = dict()
for i in range(wrapper.length):
field = wrapper.fields[i]
_type = wrapper.types[i]
wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
return theano.gof.wrapper.Wrap(wrapper, **wrap_dict)
raise theano.gof.utils.MethodNotDefined('get_params')
def prepare_node(self, node, storage_map, compute_map, impl):
......
#section support_code_apply
int APPLY_SPECIFIC(quadratic_function)(PyArrayObject* tensor, DTYPE_INPUT_0 a, DTYPE_INPUT_0 b, DTYPE_INPUT_0 c) {
NpyIter* iterator = NpyIter_New(tensor,
NPY_ITER_READWRITE | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK,
NPY_KEEPORDER, NPY_NO_CASTING, NULL);
if(iterator == NULL) {
PyErr_SetString(PyExc_RuntimeError, "Unable to iterate over a tensor for an elemwise operation.");
return -1;
}
NpyIter_IterNextFunc* get_next = NpyIter_GetIterNext(iterator, NULL);
char** data_ptr = NpyIter_GetDataPtrArray(iterator);
npy_intp* stride_ptr = NpyIter_GetInnerStrideArray(iterator);
npy_intp* innersize_ptr = NpyIter_GetInnerLoopSizePtr(iterator);
do {
char* data = *data_ptr;
npy_intp stride = *stride_ptr;
npy_intp count = *innersize_ptr;
while(count) {
DTYPE_INPUT_0 x = *((DTYPE_INPUT_0*)data);
*((DTYPE_INPUT_0*)data) = a*x*x + b*x + c;
data += stride;
--count;
}
} while(get_next(iterator));
NpyIter_Deallocate(iterator);
return 0;
}
int APPLY_SPECIFIC(compute_quadratic)(PyArrayObject* X, PyArrayObject** Y, QUADRATIC_WRAPPER* coeff) {
DTYPE_INPUT_0 a = (DTYPE_INPUT_0) (*(COEFF_TYPE*) PyArray_GETPTR1(coeff->a, 0)); // 0-D TensorType.
DTYPE_INPUT_0 b = coeff->b; // Scalar.
DTYPE_INPUT_0 c = (DTYPE_INPUT_0)PyFloat_AsDouble(coeff->c); // Generic.
Py_XDECREF(*Y);
*Y = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM(X), PyArray_DIMS(X), TYPENUM_INPUT_0, PyArray_IS_F_CONTIGUOUS(X));
if (PyArray_CopyInto(*Y, X) != 0) {
PyErr_SetString(PyExc_RuntimeError, "Unable to copy input into output.");
return 1;
};
if (APPLY_SPECIFIC(quadratic_function)(*Y, a, b, c) != 0) {
PyErr_SetString(PyExc_RuntimeError, "Unable to compute quadratic function.");
return 1;
}
return 0;
}
......@@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division
import theano
import numpy
from unittest import TestCase
from theano.gof import Op, Apply
from theano.gof import Op, COp, Apply
from theano import Generic
from theano.scalar import Scalar
from theano.tensor import TensorType
......@@ -18,7 +18,7 @@ generic_type = Generic()
# A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as op params.
class QuadraticFunction(Op):
class QuadraticOpFunc(Op):
__props__ = ('a', 'b', 'c')
params_type = Wrapper(a=tensor_type_0d,
b=scalar_type,
......@@ -39,7 +39,7 @@ class QuadraticFunction(Op):
y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c
def c_code_cache_version(self):
return (1, 2)
return (1, 3)
def c_support_code_apply(self, node, name):
float_type = node.inputs[0].type.dtype_specs()[1]
......@@ -82,9 +82,9 @@ class QuadraticFunction(Op):
float_typenum = numpy.dtype(node.inputs[0].type.dtype).num
coeff_type = 'npy_' + numpy.dtype(dtype).name
return """
%(float_type)s a = (%(float_type)s) (*(%(coeff_type)s*) PyArray_GETPTR1(%(coeff)s.a, 0)); // 0-D TensorType.
%(float_type)s b = %(coeff)s.b; // Scalar.
%(float_type)s c = (%(float_type)s)PyFloat_AsDouble(%(coeff)s.c); // Generic.
%(float_type)s a = (%(float_type)s) (*(%(coeff_type)s*) PyArray_GETPTR1(%(coeff)s->a, 0)); // 0-D TensorType.
%(float_type)s b = %(coeff)s->b; // Scalar.
%(float_type)s c = (%(float_type)s)PyFloat_AsDouble(%(coeff)s->c); // Generic.
Py_XDECREF(%(Y)s);
%(Y)s = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM(%(X)s), PyArray_DIMS(%(X)s), %(float_typenum)s, PyArray_IS_F_CONTIGUOUS(%(X)s));
if (PyArray_CopyInto(%(Y)s, %(X)s) != 0) {
......@@ -98,29 +98,54 @@ class QuadraticFunction(Op):
""" % locals()
# Same op as above, but implemented as a COp (with C code in an external file).
class QuadraticCOpFunc(COp):
__props__ = ('a', 'b', 'c')
params_type = Wrapper(a=tensor_type_0d,
b=scalar_type,
c=generic_type)
def get_op_params(self):
return [('QUADRATIC_WRAPPER', self.params_type.name),
('COEFF_TYPE', 'npy_' + numpy.dtype(dtype).name)]
def __init__(self, a, b, c):
super(QuadraticCOpFunc, self).__init__('test_quadratic_function.c',
'APPLY_SPECIFIC(compute_quadratic)')
self.a = a
self.b = b
self.c = c
def make_node(self, x):
x = tensor.as_tensor_variable(x)
return Apply(self, [x], [x.type()])
class TestWrapper(TestCase):
def test_wrap_hash_and_eq(self):
w1 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
w2 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
def test_hash_and_eq_wrap(self):
wp1 = Wrapper(a=Generic(), array=TensorType('int32', (False,)), floatting=Scalar('float32'),
npy_scalar=TensorType('float64', tuple()))
wp2 = Wrapper(a=Generic(), array=TensorType('int32', (False,)), floatting=Scalar('float32'),
npy_scalar=TensorType('float64', tuple()))
w1 = Wrap(wp1, a=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
w2 = Wrap(wp2, a=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 == w2
assert not (w1 != w2)
assert hash(w1) == hash(w2)
assert all(hasattr(w1, key) for key in ('a', 'b', 'array', 'floatting', 'npy_scalar'))
# Changing attributes names only.
w2 = Wrap(other_name=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2
# Changing attributes types only.
w2 = Wrap(a=1, b='test string', array=[1, 2, 4, 5, 7], floatting=-4.5, npy_scalar=numpy.asarray(12))
# Changing attributes names only (a -> other_name).
wp2_other = Wrapper(other_name=Generic(), array=TensorType('int32', (False,)), floatting=Scalar('float32'),
npy_scalar=TensorType('float64', tuple()))
w2 = Wrap(wp2_other, other_name=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2
# Changing attributes values only.
w2 = Wrap(a=1, b='string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
# Changing attributes values only (now a=2).
w2 = Wrap(wp2, a=2, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2
# Changing NumPy array values.
w2 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, -5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
# Changing NumPy array values (5 -> -5).
w2 = Wrap(wp2, a=1, array=numpy.asarray([1, 2, 4, -5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2
def test_wrapper_hash_and_eq(self):
def test_hash_and_eq_wrapper(self):
w1 = Wrapper(a1=TensorType('int64', (False, False)),
a2=TensorType('int64', (False, True, False, False, True)),
a3=Generic())
......@@ -133,7 +158,7 @@ class TestWrapper(TestCase):
assert w1.name == w2.name
# Changing attributes names only.
w2 = Wrapper(a1=TensorType('int64', (False, False)),
other_name=TensorType('int64', (False, True, False, False, True)),
other_name=TensorType('int64', (False, True, False, False, True)), # a2 -> other_name
a3=Generic())
assert w1 != w2
# Changing attributes types only.
......@@ -157,7 +182,8 @@ class TestWrapper(TestCase):
a3=Generic())
# With a value that does not match the wrapper.
o = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'),
o = Wrap(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'),
a2=random_tensor.astype('float32'),
a3=2000)
# should fail (o.a1 is not int32, o.a2 is not float64)
......@@ -168,7 +194,8 @@ class TestWrapper(TestCase):
w.filter(o, strict=False, allow_downcast=True)
# With a value that matches the wrapper.
o1 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
o1 = Wrap(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=random_tensor.astype('float64'),
a3=2000)
# All should pass.
......@@ -176,15 +203,17 @@ class TestWrapper(TestCase):
w.filter(o1, strict=False, allow_downcast=False)
w.filter(o1, strict=False, allow_downcast=True)
# Check value_eq and value_eq_approx.
o2 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
# Check values_eq and values_eq_approx.
o2 = Wrap(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=random_tensor.astype('float64'),
a3=2000)
assert w.values_eq(o1, o2)
assert w.values_eq_approx(o1, o2)
# Check value_eq_approx.
o3 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('float32'),
o3 = Wrap(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('float32'),
a2=random_tensor.astype('float64'),
a3=2000.0)
assert w.values_eq_approx(o1, o3)
......@@ -192,14 +221,18 @@ class TestWrapper(TestCase):
def test_op_params(self):
a, b, c = 2, 3, -7
x = tensor.matrix()
y = QuadraticFunction(a, b, c)(x)
f = theano.function([x], y)
y1 = QuadraticOpFunc(a, b, c)(x)
y2 = QuadraticCOpFunc(a, b, c)(x)
f1 = theano.function([x], y1)
f2 = theano.function([x], y2)
shape = (100, 100)
# The for-loop is here just to force profiling print something interesting.
# When running this test without this loop, profiling does not print neither list of classes nor list of ops
# (maybe because the function is extremely fast ?).
for i in range(50):
vx = numpy.random.normal(size=shape[0] * shape[1]).astype(dtype).reshape(*shape)
vy = f(vx)
vy1 = f1(vx)
vy2 = f2(vx)
ref = a * (vx**2) + b * vx + c
utt.assert_allclose(ref, vy)
utt.assert_allclose(vy1, vy2)
utt.assert_allclose(ref, vy1)
"""
Module for wrapping many Theano variables into one struct for op params.
Module for wrapping many Theano variables into one C struct for op params.
This module contains two classes:
- Wrapper: class to define the op params type.
- Wrap: internal convenient class to create an object that is compatible with a Wrapper-defined op params.
Example of usage:
- :class:`Wrapper`: main class to define the op params type.
- :class:`Wrap`: internal convenient class to create an object that is compatible with Wrapper-defined op params.
Importation:
Example of usage
----------------
from theano.gof.wrapper import Wrapper
Importation:
In an op you create:
.. code-block:: python
params_type = Wrapper(attr1=TensorType('int32', (False, False)), attr2=TensorType('float64', (True, False)))
from theano.gof.wrapper import Wrapper
If your op contains props `attr1` AND `attr2`, the op.get_params() method will
automatically try to look for it and generate an appropriate wrapped struct.
The props must be able to pass the filtering (not strict, downcasting allowed)
of corresponding types defined into Wrapper.
In an op you create:
__props__ = ('attr1', 'attr2')
def __init__(value_attr1, value_attr2):
self.attr1 = value_attr1
self.attr2 = value_attr2
.. code-block:: python
In perform() implementation (with params named `param`):
from theano.tensor import TensorType, dmatrix
params_type = Wrapper(attr1=TensorType('int32', (False, False)), attr2=dmatrix)
var1 = param.attr1
var2 = param.attr2
If your op contains props ``attr1`` *and* ``attr2``, the default ``op.get_params()`` implementation
will automatically try to look for it and generate an appropriate wrapped struct.
Props must be compatible with the corresponding types defined into the Wrapper
(we will try to convert and downcast if needed).
In c_code() implementation (with `param = sub['params']`):
.. code-block:: python
PyArrayObject* attr1 = param.attr1;
PyArrayObject* attr2 = param.attr2;
/* You won't need to free them or whatever else. */
__props__ = ('attr1', 'attr2')
def __init__(value_attr1, value_attr2):
self.attr1 = value_attr1
self.attr2 = value_attr2
In ``perform()`` implementation (with params named ``param``):
See `theano/gof/tests/test_wrapper.py` for a complete working example.
.. code-block:: python
var1 = param.attr1
var2 = param.attr2
In ``c_code()`` implementation (with ``param = sub['params']``):
.. code-block:: c
PyArrayObject* attr1 = param->attr1;
PyArrayObject* attr2 = param->attr2;
/* You won't need to free them or whatever else. */
See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_wrapper.py``
for complete working examples.
"""
from __future__ import absolute_import, print_function, division
import re
import hashlib
import numpy
from theano.gof.utils import MethodNotDefined
from theano.gof.utils import MethodNotDefined, c_cpp_keywords
from theano.gof import Type
from theano.tensor.utils import hash_from_ndarray
# NB: Maybe we should check if an attribute name is a C/C++ keyword, and raise an error if so.
# These are some lists of C/C++ keywords:
# http://fr.cppreference.com/w/cpp/keyword
# http://fr.cppreference.com/w/c/keyword
class Wrap(dict):
......@@ -60,19 +67,31 @@ class Wrap(dict):
Internal convenient class to wrap many Python objects into one
(this class is not safe as the hash method does not check if values are effectively hashable).
Example:
>>> w = Wrap(attr1=1, attr2=2.0, attri='3')
>>> print(w.attr1, w.attr2, w.attri)
>>> d = dict(a=1, b=2, c='test')
>>> w2 = Wrap(**d)
>>> print(w2.a, w2.b, w2.c)
**Example:**
.. code-block:: python
from theano.gof.wrapper import *
from theano.scalar import Scalar
# You must create a Wrapper first:
wp = Wrapper(attr1=Scalar('int32'), key2=Scalar('float32'), field3=Scalar('int64'))
# Then you can create a Wrap with the wrapper defined above and values for attributes.
w = Wrap(wp, attr1=1, key2=2.0, field3=3)
print(w.attr1, w.key2, w.field3)
d = dict(attr1=1, key2=2, field3=-1)
w2 = Wrap(wp, **d)
print(w2.attr1, w2.key2, w2.field3)
"""
def __init__(self, **kwargs):
def __init__(self, wrapper, **kwargs):
if not isinstance(wrapper, Wrapper):
raise TypeError('Wrap: 1st constructor argument should be a Wrapper.')
for field in wrapper.fields:
if field not in kwargs:
raise TypeError('Wrap: Wrapper attribute "%s" not in Wrap args.' % field)
super(Wrap, self).__init__(**kwargs)
if len(kwargs) == 0:
raise TypeError('Wrap: cannot wrap empty data.')
self.__dict__.update(wrapper=wrapper)
def __repr__(self):
return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())])
......@@ -82,33 +101,28 @@ class Wrap(dict):
raise AttributeError('Wrap: attribute "%s" does not exist.' % key)
return self[key]
def __setattr__(self, key, value):
raise NotImplementedError('Wrap is immutable')
def __setitem__(self, key, value):
raise NotImplementedError('Wrap is immutable')
def __delitem__(self, key):
raise NotImplementedError('Wrap is immutable')
def __hash__(self):
keys = sorted(self.keys())
types = []
attributes = []
for k in keys:
types += (type(self[k]),)
if isinstance(self[k], numpy.ndarray):
# Note: hash_from_ndarray returns a string, so the hash is not yet complete
# (__hash__ must return an integer).
attributes += (hash_from_ndarray(self[k]),)
else:
# No checking, data should be hashable.
attributes += (self[k],)
return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes))
return hash((type(self), self.wrapper) + tuple(
# NB: Wrapped data should have been already filtered.
self.wrapper.types[i].make_constant(self[self.wrapper.fields[i]]).signature()
for i in range(self.wrapper.length)
))
def __eq__(self, other):
if type(self) != type(other) or len(self) != len(other):
return False
for k in self:
if k not in other or not (isinstance(self[k], type(other[k])) and isinstance(other[k], type(self[k]))):
return False
if isinstance(self[k], numpy.ndarray):
if not numpy.allclose(self[k], other[k]):
return False
elif self[k] != other[k]:
return False
return True
return (type(self) == type(other) and self.wrapper == other.wrapper and all(
# NB: Wrapped data should have been already filtered.
self.wrapper.types[i].values_eq(self[self.wrapper.fields[i]], other[self.wrapper.fields[i]])
for i in range(self.wrapper.length)
))
def __ne__(self, other):
return not self.__eq__(other)
......@@ -121,18 +135,23 @@ class Wrapper(Type):
Wrapper constructor takes key-value args.
Key will be the name of the attribute in the struct.
Value is the Theano type of this attribute, ie. an instance of (a subclass of) Type
(eg. TensorType('int64', (False,))).
Value is the Theano type of this attribute, ie. an instance of (a subclass of) :class:`Type`
(eg. ``TensorType('int64', (False,))``).
In a Python code any attribute named `key` will be available via:
structObject.key
In a Python code any attribute named ``key`` will be available via::
In a C code, attributes created to represent an instance of the type associated to `key` will be available via:
structObject.key
structObject.dtype_key # e.g. from TensorType C code.
structObject.other_attribute_named_from_key
etc.
In a C code, attributes created to represent an instance of the type associated to ``key`` will be available via:
.. code-block:: c
structObject->key;
structObject->dtype_key; // e.g. from TensorType C code.
structObject->other_attribute_named_from_key;
/* etc. */
**NB**: This Type is not a complete type and should never be used for regular graph operations.
"""
def __init__(self, **kwargs):
......@@ -142,10 +161,14 @@ class Wrapper(Type):
for attribute_name in kwargs:
if re.match('^[A-Za-z_][A-Za-z0-9_]*$', attribute_name) is None:
raise SyntaxError('Wrapper: attribute "%s" should be a valid identifier.' % attribute_name)
if attribute_name in c_cpp_keywords:
print(len(c_cpp_keywords))
raise SyntaxError('Wrapper: "%s" is a potential C/C++ keyword and should not be used as attribute name.'
% attribute_name)
type_instance = kwargs[attribute_name]
type_name = type_instance.__class__.__name__
if not isinstance(type_instance, Type):
raise TypeError('Wrapper: attribute "%s" should inherit from theano Type, got "%s".'
raise TypeError('Wrapper: attribute "%s" should inherit from Theano Type, got "%s".'
% (attribute_name, type_name))
self.length = len(kwargs)
......@@ -164,49 +187,44 @@ class Wrapper(Type):
return hash((type(self),) + self.fields + self.types)
def generate_struct_name(self):
""""
This method tries to generate an unique name for the current instance.
This name is intended to be used as struct name in C code and as constant
definition to check if a similar Wrapper has already been created
(see c_support_code() below).
"""
# This method tries to generate an unique name for the current instance.
# This name is intended to be used as struct name in C code and as constant
# definition to check if a similar Wrapper has already been created
# (see c_support_code() below).
fields_string = ','.join(self.fields).encode('utf-8')
types_string = ','.join(str(t) for t in self.types).encode('utf-8')
fields_hex = hashlib.md5(fields_string).hexdigest()
types_hex = hashlib.md5(types_string).hexdigest()
return '_wrapper_%s_%s' % (fields_hex, types_hex)
def check_that_values_are_compatible(self, data, strict, allow_downcast):
def wrap_data(self, data, strict, allow_downcast):
# Try to wrap data. Raise an exception if data does not respect the Wrapper's contract.
wrap_instance = dict()
for i in range(self.length):
wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast)
return data if strict else Wrap(**wrap_instance)
return data if strict else Wrap(self, **wrap_instance)
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None):
if strict and not isinstance(data, Wrap):
raise TypeError('%s: strict mode: data should be an instance of Wrap.' % self)
return self.check_that_values_are_compatible(data, strict, allow_downcast)
return self.wrap_data(data, strict, allow_downcast)
def values_eq(self, a, b):
# We check that a and b have expected attributes and strict values.
a = self.filter(a, strict=True)
b = self.filter(b, strict=True)
# Then we compare.
for i in range(self.length):
if not self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
return False
return True
return all(self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i]))
for i in range(self.length))
def values_eq_approx(self, a, b):
# We check, wrap and round a and b if necessary.
a = self.filter(a, strict=False, allow_downcast=True)
b = self.filter(b, strict=False, allow_downcast=True)
# Then we compare.
for i in range(self.length):
if not self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
return False
return True
return all(self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i]))
for i in range(self.length))
def c_compile_args(self, c_compiler):
c_compile_args_list = []
......@@ -375,28 +393,40 @@ class Wrapper(Type):
""" % locals()
def c_code_cache_version(self):
return (1, 4)
return (1, 5)
# As this struct has constructor and destructor, it could be instanciated on stack,
# but current implementations of C ops will then pass the instance by value at functions,
# so it's better to work directly with pointers.
def c_declare(self, name, sub, check_input=True):
struct_name = self.name
return """
%(struct_name)s %(name)s;
%(struct_name)s* %(name)s;
""" % locals()
# c_init() and c_cleanup() are useless if we create the struct
# on stack, as struct class has constructor and destructor.
def c_init(self, name, sub):
return ""
# NB: It seems c_init() is not called for an op param.
# So the real initialization is done at top of c_extract.
return """
%(nams)s = NULL;
""" % locals()
def c_cleanup(self, name, sub):
return ""
return """
delete %(name)s;
%(name)s = NULL;
""" % locals()
def c_extract(self, name, sub, check_input=True):
struct_name = self.name
fail = sub['fail']
length = self.length
fields_list = '"%s"' % '", "'.join(self.fields)
return """
/* Seems c_init() is not called for a op param. So I call `new` here. */
%(name)s = new %(struct_name)s;
const char* fields[] = {%(fields_list)s};
if (py_%(name)s == Py_None) {
PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None.");
......@@ -408,8 +438,8 @@ class Wrapper(Type):
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute \\"%%s\\" in object.", fields[i]);
%(fail)s
}
%(name)s.extract(o, i);
if (%(name)s.errorOccurred()) {
%(name)s->extract(o, i);
if (%(name)s->errorOccurred()) {
/* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */
fprintf(stderr, "\\nWrapper: error when extracting value for attribute \\"%%s\\".\\n", fields[i]);
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论