提交 1a2ef20a authored 作者: notoraptor's avatar notoraptor

Reimplement Wrap from dict as superclass.

上级 adaa5f57
...@@ -55,7 +55,7 @@ from theano.tensor.utils import hash_from_ndarray ...@@ -55,7 +55,7 @@ from theano.tensor.utils import hash_from_ndarray
# http://fr.cppreference.com/w/c/keyword # http://fr.cppreference.com/w/c/keyword
class Wrap(object): class Wrap(dict):
""" """
Internal convenient class to wrap many Python objects into one Internal convenient class to wrap many Python objects into one
(this class is not safe as the hash method does not check if values are effectively hashable). (this class is not safe as the hash method does not check if values are effectively hashable).
...@@ -70,56 +70,49 @@ class Wrap(object): ...@@ -70,56 +70,49 @@ class Wrap(object):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(Wrap, self).__init__(**kwargs)
if len(kwargs) == 0: if len(kwargs) == 0:
raise TypeError('Wrap: cannot wrap empty data.') raise TypeError('Wrap: cannot wrap empty data.')
# We want to use only the params provided in kwargs to hash the object,
# so I prefer to put them into a separate attribute (self.data) instead
# of directly in self.__dict__, to avoid confusion with builtin fields.
super(Wrap, self).__setattr__('data', kwargs)
def __repr__(self): def __repr__(self):
return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self.data[k]))) for k in sorted(self.data.keys())]) return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())])
def __getattr__(self, key): def __getattr__(self, key):
if key not in self.data: if key not in self:
raise AttributeError('Wrap: attribute "%s" does not exist.' % key) raise AttributeError('Wrap: attribute "%s" does not exist.' % key)
return self.data[key] return self[key]
def __setattr__(self, key, value):
if key not in self.data:
raise AttributeError('Wrap: attribute "%s" does not exist.' % key)
self.data[key] = value
def __hash__(self): def __hash__(self):
keys = sorted(self.data.keys()) keys = sorted(self.keys())
types = [] types = []
attributes = [] attributes = []
for k in keys: for k in keys:
types += (type(self.data[k]),) types += (type(self[k]),)
if isinstance(self.data[k], numpy.ndarray): if isinstance(self[k], numpy.ndarray):
# Note: hash_from_ndarray returns a string, so the hash is not yet complete # Note: hash_from_ndarray returns a string, so the hash is not yet complete
# (__hash__ must return an integer). # (__hash__ must return an integer).
attributes += (hash_from_ndarray(self.data[k]),) attributes += (hash_from_ndarray(self[k]),)
else: else:
# No checking, data should be hashable. # No checking, data should be hashable.
attributes += (self.data[k],) attributes += (self[k],)
return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes)) return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes))
def __eq__(self, other): def __eq__(self, other):
if type(self) != type(other): if type(self) != type(other) or len(self) != len(other):
return False return False
for k in self.data: for k in self:
if (k not in other.data or if k not in other or not (isinstance(self[k], type(other[k])) and isinstance(other[k], type(self[k]))):
not isinstance(self.data[k], type(other.data[k])) or
not isinstance(other.data[k], type(self.data[k]))):
return False return False
if isinstance(self.data[k], numpy.ndarray): if isinstance(self[k], numpy.ndarray) or isinstance(other[k], numpy.ndarray):
if not numpy.allclose(self.data[k], other.data[k]): if not numpy.allclose(self[k], other[k]):
return False return False
elif self.data[k] != other.data[k]: elif self[k] != other[k]:
return False return False
return True return True
def __ne__(self, other):
return not self.__eq__(other)
class Wrapper(Type): class Wrapper(Type):
""" """
...@@ -191,10 +184,8 @@ class Wrapper(Type): ...@@ -191,10 +184,8 @@ class Wrapper(Type):
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes. # Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
if isinstance(data, dict): if strict and not isinstance(data, Wrap):
if strict: raise TypeError('%s: strict mode: data should be an instance of Wrap.' % self)
raise TypeError('%s: strict mode: data should be an object, not a dict.' % self)
data = Wrap(**data)
return self.check_that_values_are_compatible(data, strict, allow_downcast) return self.check_that_values_are_compatible(data, strict, allow_downcast)
def values_eq(self, a, b): def values_eq(self, a, b):
...@@ -384,7 +375,7 @@ class Wrapper(Type): ...@@ -384,7 +375,7 @@ class Wrapper(Type):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1, 3) return (1, 4)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
struct_name = self.name struct_name = self.name
...@@ -412,7 +403,7 @@ class Wrapper(Type): ...@@ -412,7 +403,7 @@ class Wrapper(Type):
%(fail)s %(fail)s
} }
for (int i = 0; i < %(length)s; ++i) { for (int i = 0; i < %(length)s; ++i) {
PyObject* o = PyObject_GetAttrString(py_%(name)s, fields[i]); PyObject* o = PyDict_GetItemString(py_%(name)s, fields[i]);
if (o == NULL) { if (o == NULL) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute \\"%%s\\" in object.", fields[i]); PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute \\"%%s\\" in object.", fields[i]);
%(fail)s %(fail)s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论