提交 72eeee09 authored 作者: notoraptor's avatar notoraptor

Change the checking of Apply.run_params(): Previously, it checked if

`op.get_params()` exists. Now it checks if `op.params_type` exists. Add a default implementation of `Op.get_params()` which tries to detect and filter params from the op if `op.params_type` is defined with a Wrapper; otherwise it raises a `MethodNotDefined` exception. Rewrite test_wrapper. The tested op now uses 3 theano types: TensorType, Scalar and Generic. Test methods are also renamed and rewritten to be more readable. Simplify `Wrap.__hash__`, as `Wrap` is now considered internal. The ndarrays are now hashed first with `theano.tensor.utils.hash_from_ndarray`.
上级 4f0249a5
...@@ -125,8 +125,12 @@ class Apply(Node): ...@@ -125,8 +125,12 @@ class Apply(Node):
Returns the params for the node, or NoParams if no params is set. Returns the params for the node, or NoParams if no params is set.
""" """
if hasattr(self.op, 'get_params'): if hasattr(self.op, 'params_type'):
return self.op.get_params(self) try:
return self.op.get_params(self)
except theano.gof.utils.MethodNotDefined:
# If get_params is not defined, we will return NoParams.
pass
return NoParams return NoParams
def __getstate__(self): def __getstate__(self):
......
...@@ -795,6 +795,28 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -795,6 +795,28 @@ class Op(utils.object2, PureOp, CLinkerOp):
Convenience class to bundle `PureOp` and `CLinkerOp`. Convenience class to bundle `PureOp` and `CLinkerOp`.
""" """
# We add a default get_params() implementation which will try to detect params from the op
# if params_type is set to a Wrapper. If not, we raise a MethodNotDefined exception.
def get_params(self, node):
if hasattr(self, 'params_type'):
# If params_type is a Wrapper, we try to extract params from the op.
if isinstance(self.params_type, theano.gof.wrapper.Wrapper):
wrapper = self.params_type
op_has_wrap_attributes = True
for field in wrapper.fields:
if not hasattr(self, field):
op_has_wrap_attributes = False
break
if op_has_wrap_attributes:
wrap_dict = dict()
for i in range(wrapper.length):
field = wrapper.fields[i]
_type = wrapper.types[i]
wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
return theano.gof.wrapper.Wrap(**wrap_dict)
raise theano.gof.utils.MethodNotDefined('get_params')
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
""" """
Make any special modifications that the Op needs before doing Make any special modifications that the Op needs before doing
......
...@@ -4,30 +4,31 @@ import numpy ...@@ -4,30 +4,31 @@ import numpy
from unittest import TestCase from unittest import TestCase
from theano.gof import Op, Apply from theano.gof import Op, Apply
from theano import Generic from theano import Generic
from theano.scalar import Scalar
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.common import Wrapper, Wrap from theano.gof.wrapper import Wrapper, Wrap
from theano import config from theano import config
from theano import tensor from theano import tensor
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
dtype = config.floatX dtype = config.floatX
scalar_type = TensorType(dtype, tuple()) tensor_type_0d = TensorType(dtype, tuple())
scalar_type = Scalar(dtype)
generic_type = Generic()
# A test op to compute `y = a*x^2 + bx + c` for any tensor x, # A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as parameters of that op.
# such that a, b, c are parameters of that op.
class QuadraticFunction(Op): class QuadraticFunction(Op):
__props__ = ('a', 'b', 'c') __props__ = ('a', 'b', 'c')
params_type = Wrapper(a=scalar_type, b=scalar_type, c=scalar_type) params_type = Wrapper(a=tensor_type_0d,
b=scalar_type,
c=generic_type)
def __init__(self, a, b, c): def __init__(self, a, b, c):
self.a = a self.a = a
self.b = b self.b = b
self.c = c self.c = c
def get_params(self, node):
return Wrap(a=self.a, b=self.b, c=self.c)
def make_node(self, x): def make_node(self, x):
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
...@@ -38,7 +39,7 @@ class QuadraticFunction(Op): ...@@ -38,7 +39,7 @@ class QuadraticFunction(Op):
y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c
def c_code_cache_version(self): def c_code_cache_version(self):
return (1, 1) return (1, 2)
def c_support_code_apply(self, node, name): def c_support_code_apply(self, node, name):
float_type = node.inputs[0].type.dtype_specs()[1] float_type = node.inputs[0].type.dtype_specs()[1]
...@@ -81,12 +82,9 @@ class QuadraticFunction(Op): ...@@ -81,12 +82,9 @@ class QuadraticFunction(Op):
float_typenum = numpy.dtype(node.inputs[0].type.dtype).num float_typenum = numpy.dtype(node.inputs[0].type.dtype).num
coeff_type = 'npy_' + numpy.dtype(dtype).name coeff_type = 'npy_' + numpy.dtype(dtype).name
return """ return """
PyArrayObject* o_a = %(coeff)s.a; %(float_type)s a = (%(float_type)s) (*(%(coeff_type)s*) PyArray_GETPTR1(%(coeff)s.a, 0)); // 0-D TensorType.
PyArrayObject* o_b = %(coeff)s.b; %(float_type)s b = %(coeff)s.b; // Scalar.
PyArrayObject* o_c = %(coeff)s.c; %(float_type)s c = (%(float_type)s)PyFloat_AsDouble(%(coeff)s.c); // Generic.
%(float_type)s a = (%(float_type)s) (*(%(coeff_type)s*) PyArray_GETPTR1(o_a, 0));
%(float_type)s b = (%(float_type)s) (*(%(coeff_type)s*) PyArray_GETPTR1(o_b, 0));
%(float_type)s c = (%(float_type)s) (*(%(coeff_type)s*) PyArray_GETPTR1(o_c, 0));
Py_XDECREF(%(Y)s); Py_XDECREF(%(Y)s);
%(Y)s = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM(%(X)s), PyArray_DIMS(%(X)s), %(float_typenum)s, PyArray_IS_F_CONTIGUOUS(%(X)s)); %(Y)s = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM(%(X)s), PyArray_DIMS(%(X)s), %(float_typenum)s, PyArray_IS_F_CONTIGUOUS(%(X)s));
if (PyArray_CopyInto(%(Y)s, %(X)s) != 0) { if (PyArray_CopyInto(%(Y)s, %(X)s) != 0) {
...@@ -102,7 +100,7 @@ class QuadraticFunction(Op): ...@@ -102,7 +100,7 @@ class QuadraticFunction(Op):
class TestWrapper(TestCase): class TestWrapper(TestCase):
def test_wrap_instances(self): def test_wrap_hash_and_eq(self):
w1 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) w1 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
w2 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) w2 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 == w2 assert w1 == w2
...@@ -121,7 +119,7 @@ class TestWrapper(TestCase): ...@@ -121,7 +119,7 @@ class TestWrapper(TestCase):
w2 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, -5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12)) w2 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, -5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2 assert w1 != w2
def test_wrapper_instances(self): def test_wrapper_hash_and_eq(self):
w1 = Wrapper(a1=TensorType('int64', (False, False)), w1 = Wrapper(a1=TensorType('int64', (False, False)),
a2=TensorType('int64', (False, True, False, False, True)), a2=TensorType('int64', (False, True, False, False, True)),
a3=Generic()) a3=Generic())
...@@ -150,44 +148,44 @@ class TestWrapper(TestCase): ...@@ -150,44 +148,44 @@ class TestWrapper(TestCase):
def test_wrapper_filtering(self): def test_wrapper_filtering(self):
shape_tensor5 = (1, 2, 2, 3, 2) shape_tensor5 = (1, 2, 2, 3, 2)
size_tensor5 = shape_tensor5[0] * shape_tensor5[1] * shape_tensor5[2] * shape_tensor5[3] * shape_tensor5[4] size_tensor5 = shape_tensor5[0] * shape_tensor5[1] * shape_tensor5[2] * shape_tensor5[3] * shape_tensor5[4]
random_tensor = numpy.random.normal(size=size_tensor5).astype('float64').reshape(shape_tensor5) random_tensor = numpy.random.normal(size=size_tensor5).reshape(shape_tensor5)
# With a wrapper that does not match the value. w = Wrapper(a1=TensorType('int32', (False, False)),
w = Wrapper(a1=TensorType('int64', (False, False)), a2=TensorType('float64', (False, False, False, False, False)),
a2=TensorType('float32', (False, False, False, False, False)),
a3=Generic()) a3=Generic())
# With a value that does not match the wrapper.
o = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'), o = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'),
a2=random_tensor, a2=random_tensor.astype('float32'),
a3=2000) a3=2000)
# should fail (a2 is not float32) # should fail (o.a1 is not int32, o.a2 is not float64)
self.assertRaises(TypeError, w.filter, o, True) self.assertRaises(TypeError, w.filter, o, True)
# should fail (a2 is float64, but downcast to float32 is disallowed) # should fail (o.a1 is not int32, o.a2 is not float64, and downcast is disallowed)
self.assertRaises(TypeError, w.filter, o, False, False) self.assertRaises(TypeError, w.filter, o, False, False)
# Should pass. # Should pass.
w.filter(o, strict=False, allow_downcast=True) w.filter(o, strict=False, allow_downcast=True)
# With a wrapper that matches the value. # With a value that matches the wrapper.
w = Wrapper(a1=TensorType('int64', (False, False)), o1 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=TensorType('float64', (False, False, False, False, False)), a2=random_tensor.astype('float64'),
a3=Generic()) a3=2000)
# All should pass. # All should pass.
w.filter(o, strict=True) w.filter(o1, strict=True)
w.filter(o, strict=False, allow_downcast=False) w.filter(o1, strict=False, allow_downcast=False)
w.filter(o, strict=False, allow_downcast=True) w.filter(o1, strict=False, allow_downcast=True)
# Check value_eq and value_eq_approx. # Check value_eq and value_eq_approx.
o2 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'), o2 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=random_tensor, a2=random_tensor.astype('float64'),
a3=2000) a3=2000)
assert w.values_eq(o, o2) assert w.values_eq(o1, o2)
assert w.values_eq_approx(o, o2) assert w.values_eq_approx(o1, o2)
# Check value_eq_approx. # Check value_eq_approx.
o3 = Wrap(a1=numpy.asarray([[1, 2.0, 3.000, 4, 5.0, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'), o3 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('float32'),
a2=random_tensor.astype('float32'), a2=random_tensor.astype('float64'),
a3=2000.0) a3=2000.0)
assert w.values_eq_approx(o1, o3)
assert w.values_eq_approx(o, o3)
def test_op_params(self): def test_op_params(self):
a, b, c = 2, 3, -7 a, b, c = 2, 3, -7
......
...@@ -33,6 +33,7 @@ import numpy ...@@ -33,6 +33,7 @@ import numpy
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.gof import Type from theano.gof import Type
from theano.gof.cmodule import GCC_compiler as compiler from theano.gof.cmodule import GCC_compiler as compiler
from theano.tensor.utils import hash_from_ndarray
# NB: Maybe we should check if an attribute name is a C/C++ keyword, and raise an error if so. # NB: Maybe we should check if an attribute name is a C/C++ keyword, and raise an error if so.
# These are some lists of C/C++ keywords: # These are some lists of C/C++ keywords:
...@@ -42,7 +43,8 @@ from theano.gof.cmodule import GCC_compiler as compiler ...@@ -42,7 +43,8 @@ from theano.gof.cmodule import GCC_compiler as compiler
class Wrap(object): class Wrap(object):
""" """
Convenient class to wrap many Python objects into one. Internal convenient class to wrap many Python objects into one
(this class is not safe as the hash method does not check if values are effectively hashable).
Example: Example:
>>> w = Wrap(attr1=var1, attr2=var2, attri=vari) >>> w = Wrap(attr1=var1, attr2=var2, attri=vari)
...@@ -62,7 +64,7 @@ class Wrap(object): ...@@ -62,7 +64,7 @@ class Wrap(object):
super(Wrap, self).__setattr__('data', kwargs) super(Wrap, self).__setattr__('data', kwargs)
def __repr__(self): def __repr__(self):
return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, self.data[k])) for k in sorted(self.data.keys())]) return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self.data[k]))) for k in sorted(self.data.keys())])
def __getattr__(self, key): def __getattr__(self, key):
if key not in self.data: if key not in self.data:
...@@ -81,21 +83,12 @@ class Wrap(object): ...@@ -81,21 +83,12 @@ class Wrap(object):
for k in keys: for k in keys:
types += (type(self.data[k]),) types += (type(self.data[k]),)
if isinstance(self.data[k], numpy.ndarray): if isinstance(self.data[k], numpy.ndarray):
if len(self.data[k].shape) == 0: # Note: hash_from_ndarray returns a string, so the hash is not yet complete
# NumPy scalar is not iterable, so we put it into a tuple. # (__hash__ must return an integer).
attributes += (numpy.asscalar(self.data[k]),) attributes += (hash_from_ndarray(self.data[k]),)
else:
# NumPy non-0-D arrays are iterable, so we append it as a tuple.
attributes += tuple(self.data[k])
else: else:
try: # No checking, data should be hashable.
iter(self.data[k]) attributes += (self.data[k],)
except TypeError:
# Not iterable: we put it into a tuple.
attributes += (self.data[k],)
else:
# Iterable: we append it directly.
attributes += self.data[k]
return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes)) return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes))
def __eq__(self, other): def __eq__(self, other):
...@@ -166,13 +159,13 @@ class Wrapper(Type): ...@@ -166,13 +159,13 @@ class Wrapper(Type):
def generate_struct_name(self): def generate_struct_name(self):
"""" """"
This method try to generate an unique name for the current instance. This method tries to generate an unique name for the current instance.
This name is intended to be used as struct name in C code and This name is intended to be used as struct name in C code and as constant
as constant definition to check if a similar Wrapper has already been created definition to check if a similar Wrapper has already been created
(see c_support_code() below). (see c_support_code() below).
""" """
fields_string = ','.join(self.fields) fields_string = ','.join(self.fields).encode('utf-8')
types_string = ','.join(str(t) for t in self.types) types_string = ','.join(str(t) for t in self.types).encode('utf-8')
fields_hex = hashlib.md5(fields_string).hexdigest() fields_hex = hashlib.md5(fields_string).hexdigest()
types_hex = hashlib.md5(types_string).hexdigest() types_hex = hashlib.md5(types_string).hexdigest()
return '_wrapper_struct_%s_%s' % (fields_hex, types_hex) return '_wrapper_struct_%s_%s' % (fields_hex, types_hex)
...@@ -213,16 +206,14 @@ class Wrapper(Type): ...@@ -213,16 +206,14 @@ class Wrapper(Type):
return wrapped_data return wrapped_data
def values_eq(self, a, b): def values_eq(self, a, b):
a = self.filter(a, strict=False)
b = self.filter(b, strict=False)
for i in range(self.length): for i in range(self.length):
if not self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])): if not self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
return False return False
return True return True
def values_eq_approx(self, a, b): def values_eq_approx(self, a, b):
a = self.filter(a, strict=False) a = self.filter(a, strict=False, allow_downcast=True)
b = self.filter(b, strict=False) b = self.filter(b, strict=False, allow_downcast=True)
for i in range(self.length): for i in range(self.length):
if not self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i])): if not self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论