提交 3c440a74 authored 作者: notoraptor's avatar notoraptor

Secure code.

Optimize Wrap.__hash__.
上级 c81517aa
...@@ -801,13 +801,14 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -801,13 +801,14 @@ class Op(utils.object2, PureOp, CLinkerOp):
def get_params(self, node): def get_params(self, node):
if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.Wrapper): if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.Wrapper):
wrapper = self.params_type wrapper = self.params_type
if all(hasattr(self, field) for field in wrapper.fields): if not all(hasattr(self, field) for field in wrapper.fields):
wrap_dict = dict() raise AttributeError('%s: missing attributes for Wrapper parameter.' % type(self).__name__)
for i in range(wrapper.length): wrap_dict = dict()
field = wrapper.fields[i] for i in range(wrapper.length):
_type = wrapper.types[i] field = wrapper.fields[i]
wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True) _type = wrapper.types[i]
return theano.gof.Wrap(wrapper, **wrap_dict) wrap_dict[field] = _type.filter(getattr(self, field), strict=False, allow_downcast=True)
return theano.gof.Wrap(wrapper, **wrap_dict)
raise theano.gof.utils.MethodNotDefined('get_params') raise theano.gof.utils.MethodNotDefined('get_params')
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
......
...@@ -101,7 +101,7 @@ class Wrap(dict): ...@@ -101,7 +101,7 @@ class Wrap(dict):
if field not in kwargs: if field not in kwargs:
raise TypeError('Wrap: Wrapper attribute "%s" not in Wrap args.' % field) raise TypeError('Wrap: Wrapper attribute "%s" not in Wrap args.' % field)
super(Wrap, self).__init__(**kwargs) super(Wrap, self).__init__(**kwargs)
self.__dict__.update(wrapper=wrapper) self.__dict__.update(wrapper=wrapper, signatures=None)
def __repr__(self): def __repr__(self):
return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())]) return 'Wrap(%s)' % ', '.join([('%s:%s' % (k, type(self[k]))) for k in sorted(self.keys())])
...@@ -121,11 +121,15 @@ class Wrap(dict): ...@@ -121,11 +121,15 @@ class Wrap(dict):
raise NotImplementedError('Wrap is immutable') raise NotImplementedError('Wrap is immutable')
def __hash__(self): def __hash__(self):
return hash((type(self), self.wrapper) + tuple( # As values are immutable, we can save data signatures the first time
# NB: Wrapped data should have been already filtered. # to not regenerate them in future hash() calls.
self.wrapper.types[i].make_constant(self[self.wrapper.fields[i]]).signature() if self.__dict__['signatures'] is None:
for i in range(self.wrapper.length) self.__dict__['signatures'] = tuple(
)) # NB: Wrapped data should have been already filtered.
self.wrapper.types[i].make_constant(self[self.wrapper.fields[i]]).signature()
for i in range(self.wrapper.length)
)
return hash((type(self), self.wrapper) + self.signatures)
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and self.wrapper == other.wrapper and all( return (type(self) == type(other) and self.wrapper == other.wrapper and all(
...@@ -210,7 +214,7 @@ class Wrapper(Type): ...@@ -210,7 +214,7 @@ class Wrapper(Type):
wrap_instance = dict() wrap_instance = dict()
for i in range(self.length): for i in range(self.length):
wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast) wrap_instance[self.fields[i]] = self.types[i].filter(getattr(data, self.fields[i]), strict, allow_downcast)
return data if strict else Wrap(self, **wrap_instance) return data if (strict or isinstance(data, Wrap)) else Wrap(self, **wrap_instance)
# Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes. # Returns a wrapped object with expected attributes or (in strict mode) checks that data has expected attributes.
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论