提交 1a2ef20a authored 作者: notoraptor's avatar notoraptor

Reimplement Wrap from dict as superclass.

上级 adaa5f57
......@@ -55,7 +55,7 @@ from theano.tensor.utils import hash_from_ndarray
# http://fr.cppreference.com/w/c/keyword
class Wrap(object):
class Wrap(dict):
"""
Internal convenient class to wrap many Python objects into one
(this class is not safe as the hash method does not check if values are effectively hashable).
......@@ -70,56 +70,49 @@ class Wrap(object):
"""
def __init__(self, **kwargs):
super(Wrap, self).__init__(**kwargs)
if len(kwargs) == 0:
raise TypeError('Wrap: cannot wrap empty data.')
# We want to use only the params provided in kwargs to hash the object,
# so I prefer to put them into a separate attribute (self.data) instead
# of directly in self.__dict__, to avoid confusion with builtin fields.
super(Wrap, self).__setattr__('data', kwargs)
def __repr__(self):
return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self.data[k]))) for k in sorted(self.data.keys())])
return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())])
def __getattr__(self, key):
if key not in self.data:
if key not in self:
raise AttributeError('Wrap: attribute "%s" does not exist.' % key)
return self.data[key]
def __setattr__(self, key, value):
if key not in self.data:
raise AttributeError('Wrap: attribute "%s" does not exist.' % key)
self.data[key] = value
return self[key]
def __hash__(self):
keys = sorted(self.data.keys())
keys = sorted(self.keys())
types = []
attributes = []
for k in keys:
types += (type(self.data[k]),)
if isinstance(self.data[k], numpy.ndarray):
types += (type(self[k]),)
if isinstance(self[k], numpy.ndarray):
# Note: hash_from_ndarray returns a string, so the hash is not yet complete
# (__hash__ must return an integer).
attributes += (hash_from_ndarray(self.data[k]),)
attributes += (hash_from_ndarray(self[k]),)
else:
# No checking, data should be hashable.
attributes += (self.data[k],)
attributes += (self[k],)
return hash((type(self),) + tuple(keys) + tuple(types) + tuple(attributes))
def __eq__(self, other):
if type(self) != type(other):
if type(self) != type(other) or len(self) != len(other):
return False
for k in self.data:
if (k not in other.data or
not isinstance(self.data[k], type(other.data[k])) or
not isinstance(other.data[k], type(self.data[k]))):
for k in self:
if k not in other or not (isinstance(self[k], type(other[k])) and isinstance(other[k], type(self[k]))):
return False
if isinstance(self.data[k], numpy.ndarray):
if not numpy.allclose(self.data[k], other.data[k]):
if isinstance(self[k], numpy.ndarray) or isinstance(other[k], numpy.ndarray):
if not numpy.allclose(self[k], other[k]):
return False
elif self.data[k] != other.data[k]:
elif self[k] != other[k]:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
class Wrapper(Type):
"""
......@@ -191,10 +184,8 @@ class Wrapper(Type):
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None):
if isinstance(data, dict):
if strict:
raise TypeError('%s: strict mode: data should be an object, not a dict.' % self)
data = Wrap(**data)
if strict and not isinstance(data, Wrap):
raise TypeError('%s: strict mode: data should be an instance of Wrap.' % self)
return self.check_that_values_are_compatible(data, strict, allow_downcast)
def values_eq(self, a, b):
......@@ -384,7 +375,7 @@ class Wrapper(Type):
""" % locals()
def c_code_cache_version(self):
return (1, 3)
return (1, 4)
def c_declare(self, name, sub, check_input=True):
struct_name = self.name
......@@ -412,7 +403,7 @@ class Wrapper(Type):
%(fail)s
}
for (int i = 0; i < %(length)s; ++i) {
PyObject* o = PyObject_GetAttrString(py_%(name)s, fields[i]);
PyObject* o = PyDict_GetItemString(py_%(name)s, fields[i]);
if (o == NULL) {
PyErr_Format(PyExc_TypeError, "Wrapper: missing expected attribute \\"%%s\\" in object.", fields[i]);
%(fail)s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论