提交 5a14c0c4 authored 作者: notoraptor's avatar notoraptor

Rewrite hash methods (still too heavy, I think).

Add new tests. Fix typos.
上级 be97f3ee
from __future__ import absolute_import, print_function, division
import theano
import numpy
from unittest import TestCase
from theano.gof import Op, Apply
from theano import Generic
from theano.tensor import TensorType
from theano.common import Wrapper, Wrap
from theano import config
......@@ -98,7 +100,95 @@ class QuadraticFunction(Op):
""" % locals()
def test_wrapper():
class TestWrapper(TestCase):
def test_wrap_instances(self):
w1 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
w2 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 == w2
assert hash(w1) == hash(w2)
assert all(hasattr(w1, key) for key in ('a', 'b', 'array', 'floatting', 'npy_scalar'))
# Changing attributes names only.
w2 = Wrap(other_name=1, b='test string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2
# Changing attributes types only.
w2 = Wrap(a=1, b='test string', array=[1, 2, 4, 5, 7], floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2
# Changing attributes values only.
w2 = Wrap(a=1, b='string', array=numpy.asarray([1, 2, 4, 5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2
# Changing NumPy array values.
w2 = Wrap(a=1, b='test string', array=numpy.asarray([1, 2, 4, -5, 7]), floatting=-4.5, npy_scalar=numpy.asarray(12))
assert w1 != w2
def test_wrapper_instances(self):
w1 = Wrapper(a1=TensorType('int64', (False, False)),
a2=TensorType('int64', (False, True, False, False, True)),
a3=Generic())
w2 = Wrapper(a1=TensorType('int64', (False, False)),
a2=TensorType('int64', (False, True, False, False, True)),
a3=Generic())
assert w1 == w2
assert hash(w1) == hash(w2)
# Changing attributes names only.
w2 = Wrapper(a1=TensorType('int64', (False, False)),
other_name=TensorType('int64', (False, True, False, False, True)),
a3=Generic())
assert w1 != w2
# Changing attributes types only.
w2 = Wrapper(a1=TensorType('int64', (False, False)),
a2=Generic(), # changing class
a3=Generic())
assert w1 != w2
# Changing attributes types characteristics only.
w2 = Wrapper(a1=TensorType('int64', (False, True)), # changing broadcasting
a2=TensorType('int64', (False, True, False, False, True)),
a3=Generic())
assert w1 != w2
def test_wrapper_filtering(self):
shape_tensor5 = (1, 2, 2, 3, 2)
size_tensor5 = reduce(lambda x, y: x * y, shape_tensor5, 1)
random_tensor = numpy.random.normal(size=size_tensor5).astype('float64').reshape(shape_tensor5)
# With a wrapper that does not match the value.
w = Wrapper(a1=TensorType('int64', (False, False)),
a2=TensorType('float32', (False, False, False, False, False)),
a3=Generic())
o = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'),
a2=random_tensor,
a3=2000)
# should fail (a2 is not float32)
self.assertRaises(TypeError, w.filter, o, True)
# should fail (a2 is float64, but downcast to float32 is disallowed)
self.assertRaises(TypeError, w.filter, o, False, False)
# Should pass.
w.filter(o, strict=False, allow_downcast=True)
# With a wrapper that matches the value.
w = Wrapper(a1=TensorType('int64', (False, False)),
a2=TensorType('float64', (False, False, False, False, False)),
a3=Generic())
# All should pass.
w.filter(o, strict=True)
w.filter(o, strict=False, allow_downcast=False)
w.filter(o, strict=False, allow_downcast=True)
# Check value_eq and value_eq_approx.
o2 = Wrap(a1=numpy.asarray([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]).astype('int64'),
a2=random_tensor,
a3=2000)
assert w.values_eq(o, o2)
assert w.values_eq_approx(o, o2)
# Check value_eq_approx.
o3 = Wrap(a1=numpy.asarray([[1, 2.0, 3.000, 4, 5.0, 6], [7, 8, 9, 10, 11, 12]]).astype('int32'),
a2=random_tensor.astype('float32'),
a3=2000.0)
assert w.values_eq_approx(o, o3)
def test_wrapper(self):
a, b, c = 2, 3, -7
x = tensor.matrix()
y = QuadraticFunction(a, b, c)(x)
......
......@@ -29,6 +29,7 @@ See theano/common/tests/test_wrapper.py for a complete working example.
from __future__ import absolute_import, print_function, division
import re
import hashlib
import numpy
from theano.gof.utils import MethodNotDefined
from theano.gof import Type
from theano.gof.cmodule import GCC_compiler as compiler
......@@ -71,10 +72,40 @@ class Wrap(object):
self.data[key] = value
def __hash__(self):
return hash(frozenset(self.data.items()))
keys = sorted(self.data.keys())
types = []
attributes = []
for k in keys:
types += (type(self.data[k]),)
if isinstance(self.data[k], numpy.ndarray):
if len(self.data[k].shape) == 0:
attributes += (numpy.asscalar(self.data[k]),)
else:
attributes += tuple(self.data[k])
else:
try:
iter(self.data[k])
except TypeError:
attributes += (self.data[k],)
else:
attributes += tuple(self.data[k])
return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes))
def __eq__(self, other):
return type(self) == type(other) and self.data == other.data
if type(self) != type(other):
return False
for k in self.data:
if (k not in other.data or
not isinstance(self.data[k], type(other.data[k])) or
not isinstance(other.data[k], type(self.data[k]))):
return False
if isinstance(self.data[k], numpy.ndarray):
if not numpy.allclose(self.data[k], other.data[k]):
return False
elif self.data[k] != other.data[k]:
return False
return True
# return type(self) == type(other) and self.data == other.data
class Wrapper(Type):
......@@ -178,7 +209,7 @@ class Wrapper(Type):
a = self.filter(a, strict=False)
b = self.filter(b, strict=False)
for i in range(self.length):
if not self.types[i].value_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
if not self.types[i].values_eq(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
return False
return True
......@@ -186,7 +217,7 @@ class Wrapper(Type):
a = self.filter(a, strict=False)
b = self.filter(b, strict=False)
for i in range(self.length):
if not self.types[i].value_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
if not self.types[i].values_eq_approx(getattr(a, self.fields[i]), getattr(b, self.fields[i])):
return False
return True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论