提交 3c440a74 authored 作者: notoraptor's avatar notoraptor

Secure code.

Optimize Wrap.__hash__.
上级 c81517aa
......@@ -801,7 +801,8 @@ class Op(utils.object2, PureOp, CLinkerOp):
def get_params(self, node):
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.Wrapper):
wrapper = self.params_type
if all(hasattr(self, field) for field in wrapper.fields):
if not all(hasattr(self, field) for field in wrapper.fields):
raise AttributeError('%s: missing attributes for Wrapper parameter.' % type(self).__name__)
wrap_dict = dict()
for i in range(wrapper.length):
field = wrapper.fields[i]
......
......@@ -101,7 +101,7 @@ class Wrap(dict):
if field not in kwargs:
raise TypeError('Wrap: Wrapper attribute "%s" not in Wrap args.' % field)
super(Wrap, self).__init__(**kwargs)
self.__dict__.update(wrapper=wrapper)
self.__dict__.update(wrapper=wrapper, signatures=None)
def __repr__(self):
return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())])
......@@ -121,11 +121,15 @@ class Wrap(dict):
raise NotImplementedError('Wrap is immutable')
def __hash__(self):
return hash((type(self), self.wrapper) + tuple(
# As values are immutable, we can save data signatures the first time
# to not regenerate them in future hash() calls.
if self.__dict__['signatures'] is None:
self.__dict__['signatures'] = tuple(
# NB: Wrapped data should have been already filtered.
self.wrapper.types[i].make_constant(self[self.wrapper.fields[i]]).signature()
for i in range(self.wrapper.length)
))
)
return hash((type(self), self.wrapper) + self.signatures)
def __eq__(self, other):
return (type(self) == type(other) and self.wrapper == other.wrapper and all(
......@@ -210,7 +214,7 @@ class Wrapper(Type):
wrap_instance = dict()
for i in range(self.length):
wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast)
return data if strict else Wrap(self, **wrap_instance)
return data if (strict or isinstance(data, Wrap)) else Wrap(self, **wrap_instance)
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论