提交 4f0249a5 authored 作者: notoraptor's avatar notoraptor

Move code into theano/gof.

Fix Python3 compat. Add comment to explain the code. Update unit tests.
上级 5a14c0c4
from __future__ import absolute_import, print_function, division
from .wrapper import Wrapper, Wrap
from __future__ import absolute_import, print_function, division
...@@ -11,14 +11,14 @@ from theano import tensor ...@@ -11,14 +11,14 @@ from theano import tensor
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
dtype = config.floatX dtype = config.floatX
ScalarType = TensorType(dtype, tuple()) scalar_type = TensorType(dtype, tuple())
# A test op to compute `y = a*x^2 + bx + c` for any tensor x, # A test op to compute `y = a*x^2 + bx + c` for any tensor x,
# such that a, b, c are parameters of that op. # such that a, b, c are parameters of that op.
class QuadraticFunction(Op): class QuadraticFunction(Op):
__props__ = ('a', 'b', 'c') __props__ = ('a', 'b', 'c')
params_type = Wrapper(a=ScalarType, b=ScalarType, c=ScalarType) params_type = Wrapper(a=scalar_type, b=scalar_type, c=scalar_type)
def __init__(self, a, b, c): def __init__(self, a, b, c):
self.a = a self.a = a
...@@ -43,13 +43,13 @@ class QuadraticFunction(Op): ...@@ -43,13 +43,13 @@ class QuadraticFunction(Op):
def c_support_code_apply(self, node, name): def c_support_code_apply(self, node, name):
float_type = node.inputs[0].type.dtype_specs()[1] float_type = node.inputs[0].type.dtype_specs()[1]
return """ return """
/* Computes: x = a*x*x + b*x + c for x in matrix. */ /* Computes: x = a*x*x + b*x + c for x in tensor. */
int quadratic_%(float_type)s(PyArrayObject* matrix, %(float_type)s a, %(float_type)s b, %(float_type)s c) { int quadratic_%(float_type)s(PyArrayObject* tensor, %(float_type)s a, %(float_type)s b, %(float_type)s c) {
NpyIter* iterator = NpyIter_New(matrix, NpyIter* iterator = NpyIter_New(tensor,
NPY_ITER_READWRITE | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK, NPY_ITER_READWRITE | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK,
NPY_KEEPORDER, NPY_NO_CASTING, NULL); NPY_KEEPORDER, NPY_NO_CASTING, NULL);
if(iterator == NULL) { if(iterator == NULL) {
PyErr_SetString(PyExc_RuntimeError, "Unable to iterate over a matrix for an elemwise operation."); PyErr_SetString(PyExc_RuntimeError, "Unable to iterate over a tensor for an elemwise operation.");
return -1; return -1;
} }
NpyIter_IterNextFunc* get_next = NpyIter_GetIterNext(iterator, NULL); NpyIter_IterNextFunc* get_next = NpyIter_GetIterNext(iterator, NULL);
...@@ -130,6 +130,7 @@ class TestWrapper(TestCase): ...@@ -130,6 +130,7 @@ class TestWrapper(TestCase):
a3=Generic()) a3=Generic())
assert w1 == w2 assert w1 == w2
assert hash(w1) == hash(w2) assert hash(w1) == hash(w2)
assert w1.name == w2.name
# Changing attributes names only. # Changing attributes names only.
w2 = Wrapper(a1=TensorType('int64', (False, False)), w2 = Wrapper(a1=TensorType('int64', (False, False)),
other_name=TensorType('int64', (False, True, False, False, True)), other_name=TensorType('int64', (False, True, False, False, True)),
...@@ -148,7 +149,7 @@ class TestWrapper(TestCase): ...@@ -148,7 +149,7 @@ class TestWrapper(TestCase):
def test_wrapper_filtering(self): def test_wrapper_filtering(self):
shape_tensor5 = (1, 2, 2, 3, 2) shape_tensor5 = (1, 2, 2, 3, 2)
size_tensor5 = reduce(lambda x, y: x * y, shape_tensor5, 1) size_tensor5 = shape_tensor5[0] * shape_tensor5[1] * shape_tensor5[2] * shape_tensor5[3] * shape_tensor5[4]
random_tensor = numpy.random.normal(size=size_tensor5).astype('float64').reshape(shape_tensor5) random_tensor = numpy.random.normal(size=size_tensor5).astype('float64').reshape(shape_tensor5)
# With a wrapper that does not match the value. # With a wrapper that does not match the value.
...@@ -188,7 +189,7 @@ class TestWrapper(TestCase): ...@@ -188,7 +189,7 @@ class TestWrapper(TestCase):
assert w.values_eq_approx(o, o3) assert w.values_eq_approx(o, o3)
def test_wrapper(self): def test_op_params(self):
a, b, c = 2, 3, -7 a, b, c = 2, 3, -7
x = tensor.matrix() x = tensor.matrix()
y = QuadraticFunction(a, b, c)(x) y = QuadraticFunction(a, b, c)(x)
......
...@@ -56,6 +56,9 @@ class Wrap(object): ...@@ -56,6 +56,9 @@ class Wrap(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
if len(kwargs) == 0: if len(kwargs) == 0:
raise TypeError('Wrap: cannot wrap empty data.') raise TypeError('Wrap: cannot wrap empty data.')
# We want to use only the params provided in kwargs to hash the object,
# so I prefer to put them into a separate attribute (self.data) instead
# of directly in self.__dict__, to avoid confusion with builtin fields.
super(Wrap, self).__setattr__('data', kwargs) super(Wrap, self).__setattr__('data', kwargs)
def __repr__(self): def __repr__(self):
...@@ -79,16 +82,20 @@ class Wrap(object): ...@@ -79,16 +82,20 @@ class Wrap(object):
types += (type(self.data[k]),) types += (type(self.data[k]),)
if isinstance(self.data[k], numpy.ndarray): if isinstance(self.data[k], numpy.ndarray):
if len(self.data[k].shape) == 0: if len(self.data[k].shape) == 0:
# NumPy scalar is not iterable, so we put it into a tuple.
attributes += (numpy.asscalar(self.data[k]),) attributes += (numpy.asscalar(self.data[k]),)
else: else:
# NumPy non-0-D arrays are iterable, so we append it as a tuple.
attributes += tuple(self.data[k]) attributes += tuple(self.data[k])
else: else:
try: try:
iter(self.data[k]) iter(self.data[k])
except TypeError: except TypeError:
# Not iterable: we put it into a tuple.
attributes += (self.data[k],) attributes += (self.data[k],)
else: else:
attributes += tuple(self.data[k]) # Iterable: we append it directly.
attributes += self.data[k]
return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes)) return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes))
def __eq__(self, other): def __eq__(self, other):
......
...@@ -36,7 +36,6 @@ whitelist_flake8 = [ ...@@ -36,7 +36,6 @@ whitelist_flake8 = [
"compat/six.py", # This is bundled code that will be deleted, don't fix it "compat/six.py", # This is bundled code that will be deleted, don't fix it
"__init__.py", "__init__.py",
"tests/__init__.py", "tests/__init__.py",
"common/__init__.py",
"compile/__init__.py", "compile/__init__.py",
"compile/sandbox/__init__.py", "compile/sandbox/__init__.py",
"compile/tests/__init__.py", "compile/tests/__init__.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论