提交 3c440a74 authored 作者: notoraptor's avatar notoraptor

Secure code.

Optimize Wrap.__hash__.
上级 c81517aa
...@@ -801,7 +801,8 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -801,7 +801,8 @@ class Op(utils.object2, PureOp, CLinkerOp):
def get_params(self, node): def get_params(self, node):
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.Wrapper): if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.Wrapper):
wrapper = self.params_type wrapper = self.params_type
if all(hasattr(self, field) for field in wrapper.fields): if not all(hasattr(self, field) for field in wrapper.fields):
raise AttributeError('%s: missing attributes for Wrapper parameter.' % type(self).__name__)
wrap_dict = dict() wrap_dict = dict()
for i in range(wrapper.length): for i in range(wrapper.length):
field = wrapper.fields[i] field = wrapper.fields[i]
......
...@@ -101,7 +101,7 @@ class Wrap(dict): ...@@ -101,7 +101,7 @@ class Wrap(dict):
if field not in kwargs: if field not in kwargs:
raise TypeError('Wrap: Wrapper attribute "%s" not in Wrap args.' % field) raise TypeError('Wrap: Wrapper attribute "%s" not in Wrap args.' % field)
super(Wrap, self).__init__(**kwargs) super(Wrap, self).__init__(**kwargs)
self.__dict__.update(wrapper=wrapper) self.__dict__.update(wrapper=wrapper, signatures=None)
def __repr__(self): def __repr__(self):
return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())]) return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())])
...@@ -121,11 +121,15 @@ class Wrap(dict): ...@@ -121,11 +121,15 @@ class Wrap(dict):
raise NotImplementedError('Wrap is immutable') raise NotImplementedError('Wrap is immutable')
def __hash__(self): def __hash__(self):
return hash((type(self), self.wrapper) + tuple( # As values are immutable, we can save data signatures the first time
# to not regenerate them in future hash() calls.
if self.__dict__['signatures'] is None:
self.__dict__['signatures'] = tuple(
# NB: Wrapped data should have been already filtered. # NB: Wrapped data should have been already filtered.
self.wrapper.types[i].make_constant(self[self.wrapper.fields[i]]).signature() self.wrapper.types[i].make_constant(self[self.wrapper.fields[i]]).signature()
for i in range(self.wrapper.length) for i in range(self.wrapper.length)
)) )
return hash((type(self), self.wrapper) + self.signatures)
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and self.wrapper == other.wrapper and all( return (type(self) == type(other) and self.wrapper == other.wrapper and all(
...@@ -210,7 +214,7 @@ class Wrapper(Type): ...@@ -210,7 +214,7 @@ class Wrapper(Type):
wrap_instance = dict() wrap_instance = dict()
for i in range(self.length): for i in range(self.length):
wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast) wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast)
return data if strict else Wrap(self, **wrap_instance) return data if (strict or isinstance(data, Wrap)) else Wrap(self, **wrap_instance)
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes. # Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论