提交 e04f0531 authored 作者: notoraptor's avatar notoraptor

Important renaming for classes and files:

- class Wrap -> Params - class Wrapper -> ParamsType - file wrapper.py -> params_type.py - file test_wrapper.py -> test_params_type.py Protect Wrap private fields and rewrite Wrap __repr__().
上级 3c440a74
...@@ -74,9 +74,9 @@ attribute :attr:`params_type` to an instance of your params Type. ...@@ -74,9 +74,9 @@ attribute :attr:`params_type` to an instance of your params Type.
.. note:: .. note::
If you want to have multiple parameters, Theano provides the convenient class If you want to have multiple parameters, Theano provides the convenient class
:class:`theano.gof.wrapper.Wrapper` that allows to bundle many parameters into :class:`theano.gof.params_type.ParamsType` that allows to bundle many parameters into
one object that will be available in both Python (as a Python object) and C code (as a struct). one object that will be available in both Python (as a Python object) and C code (as a struct).
See :ref:`Wrapper tutorial and API documentation <libdoc_gof_wrapper>` for more infos. See :ref:`ParamsType tutorial and API documentation <libdoc_gof_wrapper>` for more infos.
For example if we decide to use an int as the params the following For example if we decide to use an int as the params the following
would be appropriate: would be appropriate:
......
...@@ -17,5 +17,5 @@ ...@@ -17,5 +17,5 @@
fgraph fgraph
toolbox toolbox
type type
wrapper params_type
utils utils
.. _libdoc_gof_wrapper: .. _libdoc_gof_wrapper:
======================================================== ============================================================
:mod:`theano.gof.wrapper` -- Wrapper class for op params :mod:`theano.gof.params_type` -- Wrapper class for op params
======================================================== ============================================================
--------- ---------
Reference Reference
--------- ---------
.. automodule:: theano.gof.wrapper .. automodule:: theano.gof.params_type
:platform: Unix, Windows :platform: Unix, Windows
:synopsis: Wrapper class for op params :synopsis: Wrapper class for op params
:members: :members:
......
...@@ -80,7 +80,7 @@ from theano.gof.type import \ ...@@ -80,7 +80,7 @@ from theano.gof.type import \
from theano.gof.utils import \ from theano.gof.utils import \
hashtype, object2, MethodNotDefined hashtype, object2, MethodNotDefined
from theano.gof.wrapper import Wrapper, Wrap from theano.gof.params_type import ParamsType, Params
import theano import theano
......
...@@ -797,18 +797,18 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -797,18 +797,18 @@ class Op(utils.object2, PureOp, CLinkerOp):
""" """
# We add a default get_params() implementation which will try to detect params from the op # We add a default get_params() implementation which will try to detect params from the op
# if params_type is set to a Wrapper. If not, we raise a MethodNotDefined exception. # if params_type is set to a ParamsType. If not, we raise a MethodNotDefined exception.
def get_params(self, node): def get_params(self, node):
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.Wrapper): if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.ParamsType):
wrapper = self.params_type wrapper = self.params_type
if not all(hasattr(self, field) for field in wrapper.fields): if not all(hasattr(self, field) for field in wrapper.fields):
raise AttributeError('%s: missing attributes for Wrapper parameter.' % type(self).__name__) raise AttributeError('%s: missing attributes for ParamsType parameter.' % type(self).__name__)
wrap_dict = dict() wrap_dict = dict()
for i in range(wrapper.length): for i in range(wrapper.length):
field = wrapper.fields[i] field = wrapper.fields[i]
_type = wrapper.types[i] _type = wrapper.types[i]
wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True) wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
return theano.gof.Wrap(wrapper, **wrap_dict) return theano.gof.Params(wrapper, **wrap_dict)
raise theano.gof.utils.MethodNotDefined('get_params') raise theano.gof.utils.MethodNotDefined('get_params')
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
...@@ -1393,19 +1393,19 @@ class COp(Op): ...@@ -1393,19 +1393,19 @@ class COp(Op):
The names must be strings that are not a C keyword and the The names must be strings that are not a C keyword and the
values must be strings of literal C representations. values must be strings of literal C representations.
If op uses a :class:`theano.gof.wrapper.Wrapper` as ``params_type``, If op uses a :class:`theano.gof.params_type.ParamsType` as ``params_type``,
it returns: it returns:
- a default macro ``APPLY_SPECIFIC_WRAPPER`` which defines the class name of the - a default macro ``PARAMS_TYPE`` which defines the class name of the
corresponding C struct. corresponding C struct.
- a macro ``DTYPE_PARAM_key`` for every ``key`` in the Wrapper for which associated - a macro ``DTYPE_PARAM_key`` for every ``key`` in the ParamsType for which associated
type implements the method :func:`theano.gof.type.CLinkerType.c_element_type`. type implements the method :func:`theano.gof.type.CLinkerType.c_element_type`.
``DTYPE_PARAM_key`` defines the primitive C type name of an item in a variable ``DTYPE_PARAM_key`` defines the primitive C type name of an item in a variable
associated to ``key``. associated to ``key``.
""" """
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.Wrapper): if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.ParamsType):
wrapper = self.params_type wrapper = self.params_type
params = [('APPLY_SPECIFIC_WRAPPER', wrapper.name)] params = [('PARAMS_TYPE', wrapper.name)]
for i in range(wrapper.length): for i in range(wrapper.length):
try: try:
params.append(('DTYPE_PARAM_' + wrapper.fields[i], wrapper.types[i].c_element_type())) params.append(('DTYPE_PARAM_' + wrapper.fields[i], wrapper.types[i].c_element_type()))
......
""" """
Module for wrapping many Op parameters into one object available in both Python and C code. Module for wrapping many Op parameters into one object available in both Python and C code.
The module provides the main public class :class:`Wrapper` that allows to bundle many Theano types The module provides the main public class :class:`ParamsType` that allows to bundle many Theano types
into one parameter type, and an internal convenient class :class:`Wrap` which will be automatically into one parameter type, and an internal convenient class :class:`Params` which will be automatically
used to create a Wrap object that is compatible with the Wrapper-defined type. used to create a Params object that is compatible with the ParamsType defined.
The Wrap object will be available in both Python code (as a standard Python object) and C code The Params object will be available in both Python code (as a standard Python object) and C code
(as a specific struct with parameters as struct fields). To be fully-available in C code, Theano (as a specific struct with parameters as struct fields). To be fully-available in C code, Theano
types wrapped into Wrapper must provide a C interface (e.g. TensorType, Scalar, GpuArrayType, types wrapped into a ParamsType must provide a C interface (e.g. TensorType, Scalar, GpuArrayType,
or your own type. See :ref:`extending_op_params` for more details). or your own type. See :ref:`extending_op_params` for more details).
Example of usage Example of usage
...@@ -17,24 +17,24 @@ Importation: ...@@ -17,24 +17,24 @@ Importation:
.. code-block:: python .. code-block:: python
# Import wrapper class. # Import ParamsType class.
from theano.gof import Wrapper from theano.gof import ParamsType
# If you want to use a tensor and a scalar as parameters, # If you want to use a tensor and a scalar as parameters,
# you should import required Theano types. # you should import required Theano types.
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.scalar import Scalar from theano.scalar import Scalar
In an op you create: In your Op sub-class:
.. code-block:: python .. code-block:: python
params_type = Wrapper(attr1=TensorType('int32', (False, False)), attr2=Scalar('float64')) params_type = ParamsType(attr1=TensorType('int32', (False, False)), attr2=Scalar('float64'))
If your op contains attributes ``attr1`` **and** ``attr2``, the default ``op.get_params()`` If your op contains attributes ``attr1`` **and** ``attr2``, the default ``op.get_params()``
implementation will automatically try to look for it and generate an appropriate Wrap object. implementation will automatically try to look for it and generate an appropriate Params object.
Attributes must be compatible with the corresponding types defined into the Wrapper Attributes must be compatible with the corresponding types defined into the ParamsType
(we will try to convert and downcast if needed). For example, ``your_op.attr1`` (we will try to convert and downcast if needed). In this example, ``your_op.attr1``
should be a matrix of integers, and ``your_op.attr2`` should be a matrix of integers, and ``your_op.attr2``
should be a real number (integer or floating value). should be a real number (integer or floating value).
...@@ -60,7 +60,7 @@ In ``c_code()`` implementation (with ``param = sub['params']``): ...@@ -60,7 +60,7 @@ In ``c_code()`` implementation (with ``param = sub['params']``):
/* You won't need to free them or whatever else. */ /* You won't need to free them or whatever else. */
See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_wrapper.py`` See :class:`QuadraticOpFunc` and :class:`QuadraticCOpFunc` in ``theano/gof/tests/test_params_type.py``
for complete working examples. for complete working examples.
""" """
...@@ -72,7 +72,7 @@ from theano.gof.utils import MethodNotDefined, c_cpp_keywords ...@@ -72,7 +72,7 @@ from theano.gof.utils import MethodNotDefined, c_cpp_keywords
from theano.gof import Type from theano.gof import Type
class Wrap(dict): class Params(dict):
""" """
Internal convenient class to wrap many Python objects into one Internal convenient class to wrap many Python objects into one
(this class is not safe as the hash method does not check if values are effectively hashable). (this class is not safe as the hash method does not check if values are effectively hashable).
...@@ -81,73 +81,77 @@ class Wrap(dict): ...@@ -81,73 +81,77 @@ class Wrap(dict):
.. code-block:: python .. code-block:: python
from theano.gof import Wrapper, Wrap from theano.gof import ParamsType, Params
from theano.scalar import Scalar from theano.scalar import Scalar
# You must create a Wrapper first: # You must create a ParamsType first:
wp = Wrapper(attr1=Scalar('int32'), key2=Scalar('float32'), field3=Scalar('int64')) params_type = ParamsType(attr1=Scalar('int32'),
# Then you can create a Wrap with the wrapper defined above and values for attributes. key2=Scalar('float32'),
w = Wrap(wp, attr1=1, key2=2.0, field3=3) field3=Scalar('int64'))
print(w.attr1, w.key2, w.field3) # Then you can create a Params object with
d = dict(attr1=1, key2=2, field3=-1) # the params type defined above and values for attributes.
w2 = Wrap(wp, **d) params = Params(params_type, attr1=1, key2=2.0, field3=3)
print(w2.attr1, w2.key2, w2.field3) print(params.attr1, params.key2, params.field3)
d = dict(attr1=1, key2=2.5, field3=-1)
params2 = Params(params_type, **d)
print(params2.attr1, params2.key2, params2.field3)
""" """
def __init__(self, wrapper, **kwargs): def __init__(self, params_type, **kwargs):
if not isinstance(wrapper, Wrapper): if not isinstance(params_type, ParamsType):
raise TypeError('Wrap: 1st constructor argument should be a Wrapper.') raise TypeError('Params: 1st constructor argument should be a ParamsType.')
for field in wrapper.fields: for field in params_type.fields:
if field not in kwargs: if field not in kwargs:
raise TypeError('Wrap: Wrapper attribute "%s" not in Wrap args.' % field) raise TypeError('Params: ParamsType attribute "%s" not in Params args.' % field)
super(Wrap, self).__init__(**kwargs) super(Params, self).__init__(**kwargs)
self.__dict__.update(wrapper=wrapper, signatures=None) self.__dict__.update(__params_type__=params_type, __signatures__=None)
def __repr__(self): def __repr__(self):
return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())]) return 'Params(%s)' % ', '.join([('%s:%s:%s' % (k, type(self[k]).__name__, self[k])) for k in sorted(self.keys())])
def __getattr__(self, key): def __getattr__(self, key):
if key not in self: if key not in self:
raise AttributeError('Wrap: attribute "%s" does not exist.' % key) raise AttributeError('Params: attribute "%s" does not exist.' % key)
return self[key] return self[key]
def __setattr__(self, key, value): def __setattr__(self, key, value):
raise NotImplementedError('Wrap is immutable') raise NotImplementedError('Params is immutable')
def __setitem__(self, key, value): def __setitem__(self, key, value):
raise NotImplementedError('Wrap is immutable') raise NotImplementedError('Params is immutable')
def __delitem__(self, key): def __delitem__(self, key):
raise NotImplementedError('Wrap is immutable') raise NotImplementedError('Params is immutable')
def __hash__(self): def __hash__(self):
# As values are immutable, we can save data signatures the first time # As values are immutable, we can save data signatures the first time
# to not regenerate them in future hash() calls. # to not regenerate them in future hash() calls.
if self.__dict__['signatures'] is None: if self.__signatures__ is None:
self.__dict__['signatures'] = tuple( # NB: For writing, we must bypass setattr() which is always called by default by Python.
# NB: Wrapped data should have been already filtered. self.__dict__['__signatures__'] = tuple(
self.wrapper.types[i].make_constant(self[self.wrapper.fields[i]]).signature() # NB: Params object should have been already filtered.
for i in range(self.wrapper.length) self.__params_type__.types[i].make_constant(self[self.__params_type__.fields[i]]).signature()
for i in range(self.__params_type__.length)
) )
return hash((type(self), self.wrapper) + self.signatures) return hash((type(self), self.__params_type__) + self.__signatures__)
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and self.wrapper == other.wrapper and all( return (type(self) == type(other) and self.__params_type__ == other.__params_type__ and all(
# NB: Wrapped data should have been already filtered. # NB: Params object should have been already filtered.
self.wrapper.types[i].values_eq(self[self.wrapper.fields[i]], other[self.wrapper.fields[i]]) self.__params_type__.types[i].values_eq(self[self.__params_type__.fields[i]], other[self.__params_type__.fields[i]])
for i in range(self.wrapper.length) for i in range(self.__params_type__.length)
)) ))
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
class Wrapper(Type): class ParamsType(Type):
""" """
This class can create a struct of Theano types (like TensorType, GpuArrayType, etc.) This class can create a struct of Theano types (like TensorType, GpuArrayType, etc.)
to be used as a convenience op parameter wrapping many data. to be used as a convenience op parameter wrapping many data.
Wrapper constructor takes key-value args. ParamsType constructor takes key-value args.
Key will be the name of the attribute in the struct. Key will be the name of the attribute in the struct.
Value is the Theano type of this attribute, ie. an instance of (a subclass of) :class:`Type` Value is the Theano type of this attribute, ie. an instance of (a subclass of) :class:`Type`
(eg. ``TensorType('int64', (False,))``). (eg. ``TensorType('int64', (False,))``).
...@@ -170,18 +174,18 @@ class Wrapper(Type): ...@@ -170,18 +174,18 @@ class Wrapper(Type):
def __init__(self, **kwargs): def __init__(self, **kwargs):
if len(kwargs) == 0: if len(kwargs) == 0:
raise ValueError('Cannot create Wrapper from empty data.') raise ValueError('Cannot create ParamsType from empty data.')
for attribute_name in kwargs: for attribute_name in kwargs:
if re.match('^[A-Za-z_][A-Za-z0-9_]*$', attribute_name) is None: if re.match('^[A-Za-z_][A-Za-z0-9_]*$', attribute_name) is None:
raise AttributeError('Wrapper: attribute "%s" should be a valid identifier.' % attribute_name) raise AttributeError('ParamsType: attribute "%s" should be a valid identifier.' % attribute_name)
if attribute_name in c_cpp_keywords: if attribute_name in c_cpp_keywords:
raise SyntaxError('Wrapper: "%s" is a potential C/C++ keyword and should not be used as attribute name.' raise SyntaxError('ParamsType: "%s" is a potential C/C++ keyword and should not be used as attribute name.'
% attribute_name) % attribute_name)
type_instance = kwargs[attribute_name] type_instance = kwargs[attribute_name]
type_name = type_instance.__class__.__name__ type_name = type_instance.__class__.__name__
if not isinstance(type_instance, Type): if not isinstance(type_instance, Type):
raise TypeError('Wrapper: attribute "%s" should inherit from Theano Type, got "%s".' raise TypeError('ParamsType: attribute "%s" should inherit from Theano Type, got "%s".'
% (attribute_name, type_name)) % (attribute_name, type_name))
self.length = len(kwargs) self.length = len(kwargs)
...@@ -190,7 +194,7 @@ class Wrapper(Type): ...@@ -190,7 +194,7 @@ class Wrapper(Type):
self.name = self.generate_struct_name() self.name = self.generate_struct_name()
def __repr__(self): def __repr__(self):
return 'Wrapper<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)]) return 'ParamsType<%s>' % ', '.join([('%s:%s' % (self.fields[i], self.types[i])) for i in range(self.length)])
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and self.fields == other.fields and self.types == other.types) return (type(self) == type(other) and self.fields == other.fields and self.types == other.types)
...@@ -201,26 +205,22 @@ class Wrapper(Type): ...@@ -201,26 +205,22 @@ class Wrapper(Type):
def generate_struct_name(self): def generate_struct_name(self):
# This method tries to generate an unique name for the current instance. # This method tries to generate an unique name for the current instance.
# This name is intended to be used as struct name in C code and as constant # This name is intended to be used as struct name in C code and as constant
# definition to check if a similar Wrapper has already been created # definition to check if a similar ParamsType has already been created
# (see c_support_code() below). # (see c_support_code() below).
fields_string = ','.join(self.fields).encode('utf-8') fields_string = ','.join(self.fields).encode('utf-8')
types_string = ','.join(str(t) for t in self.types).encode('utf-8') types_string = ','.join(str(t) for t in self.types).encode('utf-8')
fields_hex = hashlib.md5(fields_string).hexdigest() fields_hex = hashlib.md5(fields_string).hexdigest()
types_hex = hashlib.md5(types_string).hexdigest() types_hex = hashlib.md5(types_string).hexdigest()
return '_wrapper_%s_%s' % (fields_hex, types_hex) return '_Params_%s_%s' % (fields_hex, types_hex)
def wrap_data(self, data, strict, allow_downcast): # Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes.
# Try to wrap data. Raise an exception if data does not respect the Wrapper's contract.
wrap_instance = dict()
for i in range(self.length):
wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast)
return data if (strict or isinstance(data, Wrap)) else Wrap(self, **wrap_instance)
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if strict and not isinstance(data, Wrap): if strict and not isinstance(data, Params):
raise TypeError('%s: strict mode: data should be an instance of Wrap.' % self) raise TypeError('%s: strict mode: data should be an instance of Params.' % self)
return self.wrap_data(data, strict, allow_downcast) # Filter data attributes to check if they respect the ParamsType's contract.
filtered = {self.fields[i]: self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast)
for i in range(self.length)}
return data if (strict or isinstance(data, Params)) else Params(self, **filtered)
def values_eq(self, a, b): def values_eq(self, a, b):
return all(self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])) return all(self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i]))
...@@ -341,7 +341,7 @@ class Wrapper(Type): ...@@ -341,7 +341,7 @@ class Wrapper(Type):
%s %s
// Default case. // Default case.
default: default:
PyErr_Format(PyExc_TypeError, "Wrapper: no extraction defined for a field %%d.", field_pos); PyErr_Format(PyExc_TypeError, "ParamsType: no extraction defined for a field %%d.", field_pos);
this->setErrorOccurred(); this->setErrorOccurred();
break; break;
} }
...@@ -394,8 +394,7 @@ class Wrapper(Type): ...@@ -394,8 +394,7 @@ class Wrapper(Type):
struct_extract_method=struct_extract_method) struct_extract_method=struct_extract_method)
def c_code_cache_version(self): def c_code_cache_version(self):
wrapper_c_code_version = (1, 6) return ((1, 7), tuple(t.c_code_cache_version() for t in self.types))
return (wrapper_c_code_version, tuple(t.c_code_cache_version() for t in self.types))
# As this struct has constructor and destructor, it could be instanciated on stack, # As this struct has constructor and destructor, it could be instanciated on stack,
# but current implementations of C ops will then pass the instance by value at functions, # but current implementations of C ops will then pass the instance by value at functions,
...@@ -426,20 +425,20 @@ class Wrapper(Type): ...@@ -426,20 +425,20 @@ class Wrapper(Type):
const char* fields[] = {%(fields_list)s}; const char* fields[] = {%(fields_list)s};
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
PyErr_SetString(PyExc_ValueError, "Wrapper: expected an object, not None."); PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
%(fail)s %(fail)s
} }
for (int i = 0; i < %(length)s; ++i) { for (int i = 0; i < %(length)s; ++i) {
PyObject* o = PyDict_GetItemString(py_%(name)s, fields[i]); PyObject* o = PyDict_GetItemString(py_%(name)s, fields[i]);
if (o == NULL) { if (o == NULL) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute \\"%%s\\" in object.", fields[i]); PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute \\"%%s\\" in object.", fields[i]);
%(fail)s %(fail)s
} }
%(name)s->extract(o, i); %(name)s->extract(o, i);
if (%(name)s->errorOccurred()) { if (%(name)s->errorOccurred()) {
/* The extract code from attribute type should have already raised a Python exception, /* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */ * so we just print the attribute name in stderr. */
fprintf(stderr, "\\nWrapper: error when extracting value for attribute \\"%%s\\".\\n", fields[i]); fprintf(stderr, "\\nParamsType: error when extracting value for attribute \\"%%s\\".\\n", fields[i]);
%(fail)s %(fail)s
} }
} }
......
...@@ -6,7 +6,7 @@ from theano.gof import Op, COp, Apply ...@@ -6,7 +6,7 @@ from theano.gof import Op, COp, Apply
from theano import Generic from theano import Generic
from theano.scalar import Scalar from theano.scalar import Scalar
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.gof import Wrapper, Wrap from theano.gof import ParamsType, Params
from theano import tensor from theano import tensor
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -18,9 +18,9 @@ generic_type = Generic() ...@@ -18,9 +18,9 @@ generic_type = Generic()
# A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as op params. # A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as op params.
class QuadraticOpFunc(Op): class QuadraticOpFunc(Op):
__props__ = ('a', 'b', 'c') __props__ = ('a', 'b', 'c')
params_type = Wrapper(a=tensor_type_0d, params_type = ParamsType(a=tensor_type_0d,
b=scalar_type, b=scalar_type,
c=generic_type) c=generic_type)
def __init__(self, a, b, c): def __init__(self, a, b, c):
self.a = a self.a = a
...@@ -93,9 +93,9 @@ class QuadraticOpFunc(Op): ...@@ -93,9 +93,9 @@ class QuadraticOpFunc(Op):
# Same op as above, but implemented as a COp (with C code in an external file). # Same op as above, but implemented as a COp (with C code in an external file).
class QuadraticCOpFunc(COp): class QuadraticCOpFunc(COp):
__props__ = ('a', 'b', 'c') __props__ = ('a', 'b', 'c')
params_type = Wrapper(a=tensor_type_0d, params_type = ParamsType(a=tensor_type_0d,
b=scalar_type, b=scalar_type,
c=generic_type) c=generic_type)
def __init__(self, a, b, c): def __init__(self, a, b, c):
super(QuadraticCOpFunc, self).__init__('test_quadratic_function.c', super(QuadraticCOpFunc, self).__init__('test_quadratic_function.c',
...@@ -114,71 +114,71 @@ class QuadraticCOpFunc(COp): ...@@ -114,71 +114,71 @@ class QuadraticCOpFunc(COp):
y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c
class TestWrapper(TestCase): class TestParamsType(TestCase):
def test_hash_and_eq_wrap(self): def test_hash_and_eq_params(self):
wp1 = Wrapper(a=Generic(), array=TensorType('int64', (False,)), floatting=Scalar('float64'), wp1 = ParamsType(a=Generic(), array=TensorType('int64', (False,)), floatting=Scalar('float64'),
npy_scalar=TensorType('float64', tuple())) npy_scalar=TensorType('float64', tuple()))
wp2 = Wrapper(a=Generic(), array=TensorType('int64', (False,)), floatting=Scalar('float64'), wp2 = ParamsType(a=Generic(), array=TensorType('int64', (False,)), floatting=Scalar('float64'),
npy_scalar=TensorType('float64', tuple())) npy_scalar=TensorType('float64', tuple()))
w1 = Wrap(wp1, a=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) w1 = Params(wp1, a=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
w2 = Wrap(wp2, a=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) w2 = Params(wp2, a=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 == w2 assert w1 == w2
assert not (w1 != w2) assert not (w1 != w2)
assert hash(w1) == hash(w2) assert hash(w1) == hash(w2)
# Changing attributes names only (a -> other_name). # Changing attributes names only (a -> other_name).
wp2_other = Wrapper(other_name=Generic(), array=TensorType('int64', (False,)), floatting=Scalar('float64'), wp2_other = ParamsType(other_name=Generic(), array=TensorType('int64', (False,)), floatting=Scalar('float64'),
npy_scalar=TensorType('float64', tuple())) npy_scalar=TensorType('float64', tuple()))
w2 = Wrap(wp2_other, other_name=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) w2 = Params(wp2_other, other_name=1, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2 assert w1 != w2
# Changing attributes values only (now a=2). # Changing attributes values only (now a=2).
w2 = Wrap(wp2, a=2, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) w2 = Params(wp2, a=2, array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2 assert w1 != w2
# Changing NumPy array values (5 -> -5). # Changing NumPy array values (5 -> -5).
w2 = Wrap(wp2, a=1, array=numpy.asarray([1, 2, 4, -5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) w2 = Params(wp2, a=1, array=numpy.asarray([1, 2, 4, -5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2 assert w1 != w2
def test_hash_and_eq_wrapper(self): def test_hash_and_eq_params_type(self):
w1 = Wrapper(a1=TensorType('int64', (False, False)), w1 = ParamsType(a1=TensorType('int64', (False, False)),
a2=TensorType('int64', (False, True, False, False, True)), a2=TensorType('int64', (False, True, False, False, True)),
a3=Generic()) a3=Generic())
w2 = Wrapper(a1=TensorType('int64', (False, False)), w2 = ParamsType(a1=TensorType('int64', (False, False)),
a2=TensorType('int64', (False, True, False, False, True)), a2=TensorType('int64', (False, True, False, False, True)),
a3=Generic()) a3=Generic())
assert w1 == w2 assert w1 == w2
assert not (w1 != w2) assert not (w1 != w2)
assert hash(w1) == hash(w2) assert hash(w1) == hash(w2)
assert w1.name == w2.name assert w1.name == w2.name
# Changing attributes names only. # Changing attributes names only.
w2 = Wrapper(a1=TensorType('int64', (False, False)), w2 = ParamsType(a1=TensorType('int64', (False, False)),
other_name=TensorType('int64', (False, True, False, False, True)), # a2 -> other_name other_name=TensorType('int64', (False, True, False, False, True)), # a2 -> other_name
a3=Generic()) a3=Generic())
assert w1 != w2 assert w1 != w2
# Changing attributes types only. # Changing attributes types only.
w2 = Wrapper(a1=TensorType('int64', (False, False)), w2 = ParamsType(a1=TensorType('int64', (False, False)),
a2=Generic(), # changing class a2=Generic(), # changing class
a3=Generic()) a3=Generic())
assert w1 != w2 assert w1 != w2
# Changing attributes types characteristics only. # Changing attributes types characteristics only.
w2 = Wrapper(a1=TensorType('int64', (False, True)), # changing broadcasting w2 = ParamsType(a1=TensorType('int64', (False, True)), # changing broadcasting
a2=TensorType('int64', (False, True, False, False, True)), a2=TensorType('int64', (False, True, False, False, True)),
a3=Generic()) a3=Generic())
assert w1 != w2 assert w1 != w2
def test_wrapper_filtering(self): def test_params_type_filtering(self):
shape_tensor5 = (1, 2, 2, 3, 2) shape_tensor5 = (1, 2, 2, 3, 2)
size_tensor5 = shape_tensor5[0] * shape_tensor5[1] * shape_tensor5[2] * shape_tensor5[3] * shape_tensor5[4] size_tensor5 = shape_tensor5[0] * shape_tensor5[1] * shape_tensor5[2] * shape_tensor5[3] * shape_tensor5[4]
random_tensor = numpy.random.normal(size=size_tensor5).reshape(shape_tensor5) random_tensor = numpy.random.normal(size=size_tensor5).reshape(shape_tensor5)
w = Wrapper(a1=TensorType('int32', (False, False)), w = ParamsType(a1=TensorType('int32', (False, False)),
a2=TensorType('float64', (False, False, False, False, False)), a2=TensorType('float64', (False, False, False, False, False)),
a3=Generic()) a3=Generic())
# With a value that does not match the wrapper. # With a value that does not match the params type.
o = Wrap(w, o = Params(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'), a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'),
a2=random_tensor.astype('float32'), a2=random_tensor.astype('float32'),
a3=2000) a3=2000)
# should fail (o.a1 is not int32, o.a2 is not float64) # should fail (o.a1 is not int32, o.a2 is not float64)
self.assertRaises(TypeError, w.filter, o, True) self.assertRaises(TypeError, w.filter, o, True)
# should fail (o.a1 is not int32, o.a2 is not float64, and downcast is disallowed) # should fail (o.a1 is not int32, o.a2 is not float64, and downcast is disallowed)
...@@ -186,31 +186,31 @@ class TestWrapper(TestCase): ...@@ -186,31 +186,31 @@ class TestWrapper(TestCase):
# Should pass. # Should pass.
w.filter(o, strict=False, allow_downcast=True) w.filter(o, strict=False, allow_downcast=True)
# With a value that matches the wrapper. # With a value that matches the params type.
o1 = Wrap(w, o1 = Params(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'), a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=random_tensor.astype('float64'), a2=random_tensor.astype('float64'),
a3=2000) a3=2000)
# All should pass. # All should pass.
w.filter(o1, strict=True) w.filter(o1, strict=True)
w.filter(o1, strict=False, allow_downcast=False) w.filter(o1, strict=False, allow_downcast=False)
w.filter(o1, strict=False, allow_downcast=True) w.filter(o1, strict=False, allow_downcast=True)
# Check values_eq and values_eq_approx. # Check values_eq and values_eq_approx.
o2 = Wrap(w, o2 = Params(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'), a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=random_tensor.astype('float64'), a2=random_tensor.astype('float64'),
a3=2000) a3=2000)
assert w.values_eq(o1, o2) assert w.values_eq(o1, o2)
assert w.values_eq_approx(o1, o2) assert w.values_eq_approx(o1, o2)
# Check value_eq_approx. # Check value_eq_approx.
# NB: I don't know exactly which kind of differences is rejected by values_eq but accepted by values_eq_approx. # NB: I don't know exactly which kind of differences is rejected by values_eq but accepted by values_eq_approx.
# So, I just play a little with float values. # So, I just play a little with float values.
o3 = Wrap(w, o3 = Params(w,
a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'), a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=(random_tensor.astype('float32') * 10 / 2.2 * 2.19999999999 / 10).astype('float64'), a2=(random_tensor.astype('float32') * 10 / 2.2 * 2.19999999999 / 10).astype('float64'),
a3=2000.0 - 0.00000000000000001) a3=2000.0 - 0.00000000000000001)
assert w.values_eq_approx(o1, o3) assert w.values_eq_approx(o1, o3)
def test_op_params(self): def test_op_params(self):
......
...@@ -26,7 +26,7 @@ int APPLY_SPECIFIC(quadratic_function)(PyArrayObject* tensor, DTYPE_INPUT_0 a, D ...@@ -26,7 +26,7 @@ int APPLY_SPECIFIC(quadratic_function)(PyArrayObject* tensor, DTYPE_INPUT_0 a, D
return 0; return 0;
} }
int APPLY_SPECIFIC(compute_quadratic)(PyArrayObject* X, PyArrayObject** Y, APPLY_SPECIFIC_WRAPPER* coeff) { int APPLY_SPECIFIC(compute_quadratic)(PyArrayObject* X, PyArrayObject** Y, PARAMS_TYPE* coeff) {
DTYPE_INPUT_0 a = (DTYPE_INPUT_0) (*(DTYPE_PARAM_a*) PyArray_GETPTR1(coeff->a, 0)); // 0-D TensorType. DTYPE_INPUT_0 a = (DTYPE_INPUT_0) (*(DTYPE_PARAM_a*) PyArray_GETPTR1(coeff->a, 0)); // 0-D TensorType.
DTYPE_INPUT_0 b = coeff->b; // Scalar. DTYPE_INPUT_0 b = coeff->b; // Scalar.
DTYPE_INPUT_0 c = (DTYPE_INPUT_0) PyFloat_AsDouble(coeff->c); // Generic. DTYPE_INPUT_0 c = (DTYPE_INPUT_0) PyFloat_AsDouble(coeff->c); // Generic.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论