提交 4f0249a5 authored 作者: notoraptor's avatar notoraptor

Move code into theano/gof.

Fix Python3 compat. Add comment to explain the code. Update unit tests.
上级 5a14c0c4
from __future__ import absolute_import, print_function, division
from .wrapper import Wrapper, Wrap
from __future__ import absolute_import, print_function, division
......@@ -11,14 +11,14 @@ from theano import tensor
from theano.tests import unittest_tools as utt
dtype = config.floatX
ScalarType = TensorType(dtype, tuple())
scalar_type = TensorType(dtype, tuple())
# A test op to compute `y = a*x^2 + bx + c` for any tensor x,
# such that a, b, c are parameters of that op.
class QuadraticFunction(Op):
__props__ = ('a', 'b', 'c')
params_type = Wrapper(a=ScalarType, b=ScalarType, c=ScalarType)
params_type = Wrapper(a=scalar_type, b=scalar_type, c=scalar_type)
def __init__(self, a, b, c):
self.a = a
......@@ -43,13 +43,13 @@ class QuadraticFunction(Op):
def c_support_code_apply(self, node, name):
float_type = node.inputs[0].type.dtype_specs()[1]
return """
/* Computes: x = a*x*x + b*x + c for x in matrix. */
int quadratic_%(float_type)s(PyArrayObject* matrix, %(float_type)s a, %(float_type)s b, %(float_type)s c) {
NpyIter* iterator = NpyIter_New(matrix,
/* Computes: x = a*x*x + b*x + c for x in tensor. */
int quadratic_%(float_type)s(PyArrayObject* tensor, %(float_type)s a, %(float_type)s b, %(float_type)s c) {
NpyIter* iterator = NpyIter_New(tensor,
NPY_ITER_READWRITE | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK,
NPY_KEEPORDER, NPY_NO_CASTING, NULL);
if(iterator == NULL) {
PyErr_SetString(PyExc_RuntimeError, "Unable to iterate over a matrix for an elemwise operation.");
PyErr_SetString(PyExc_RuntimeError, "Unable to iterate over a tensor for an elemwise operation.");
return -1;
}
NpyIter_IterNextFunc* get_next = NpyIter_GetIterNext(iterator, NULL);
......@@ -130,6 +130,7 @@ class TestWrapper(TestCase):
a3=Generic())
assert w1 == w2
assert hash(w1) == hash(w2)
assert w1.name == w2.name
# Changing attributes names only.
w2 = Wrapper(a1=TensorType('int64', (False, False)),
other_name=TensorType('int64', (False, True, False, False, True)),
......@@ -148,7 +149,7 @@ class TestWrapper(TestCase):
def test_wrapper_filtering(self):
shape_tensor5 = (1, 2, 2, 3, 2)
size_tensor5 = reduce(lambda x, y: x * y, shape_tensor5, 1)
size_tensor5 = shape_tensor5[0] * shape_tensor5[1] * shape_tensor5[2] * shape_tensor5[3] * shape_tensor5[4]
random_tensor = numpy.random.normal(size=size_tensor5).astype('float64').reshape(shape_tensor5)
# With a wrapper that does not match the value.
......@@ -188,7 +189,7 @@ class TestWrapper(TestCase):
assert w.values_eq_approx(o, o3)
def test_wrapper(self):
def test_op_params(self):
a, b, c = 2, 3, -7
x = tensor.matrix()
y = QuadraticFunction(a, b, c)(x)
......
......@@ -56,6 +56,9 @@ class Wrap(object):
def __init__(self, **kwargs):
if len(kwargs) == 0:
raise TypeError('Wrap: cannot wrap empty data.')
# We want to use only the params provided in kwargs to hash the object,
# so I prefer to put them into a separate attribute (self.data) instead
# of directly in self.__dict__, to avoid confusion with builtin fields.
super(Wrap, self).__setattr__('data', kwargs)
def __repr__(self):
......@@ -79,16 +82,20 @@ class Wrap(object):
types += (type(self.data[k]),)
if isinstance(self.data[k], numpy.ndarray):
if len(self.data[k].shape) == 0:
# NumPy scalar is not iterable, so we put it into a tuple.
attributes += (numpy.asscalar(self.data[k]),)
else:
# NumPy non-0-D arrays are iterable, so we append it as a tuple.
attributes += tuple(self.data[k])
else:
try:
iter(self.data[k])
except TypeError:
# Not iterable: we put it into a tuple.
attributes += (self.data[k],)
else:
attributes += tuple(self.data[k])
# Iterable: we append it directly.
attributes += self.data[k]
return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes))
def __eq__(self, other):
......
......@@ -36,7 +36,6 @@ whitelist_flake8 = [
"compat/six.py", # This is bundled code that will be deleted, don't fix it
"__init__.py",
"tests/__init__.py",
"common/__init__.py",
"compile/__init__.py",
"compile/sandbox/__init__.py",
"compile/tests/__init__.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论