提交 a5ce2a09 authored 作者: sentient07's avatar sentient07

Changing dict to frozen dict

上级 42750428
import collections
import operator
import functools
class frozendict(collections.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.Mapping`
interface. It can be used as a drop-in replacement for dictionaries where immutability and ordering are desired.
"""
dict_cls = dict
def __init__(self, *args, **kwargs):
self._dict = self.dict_cls(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._dict[key]
def __contains__(self, key):
return key in self._dict
def copy(self, **add_or_replace):
return self.__class__(self, **add_or_replace)
def __iter__(self):
return iter(self._dict)
def __len__(self):
return len(self._dict)
def __repr__(self):
return '<%s %r>' % (self.__class__.__name__, self._dict)
def __hash__(self):
if self._hash is None:
hashes = map(hash, self.items())
self._hash = functools.reduce(operator.xor, hashes, 0)
return self._hash
class FrozenOrderedDict(frozendict):
"""
A FrozenDict subclass that maintains key order
"""
dict_cls = collections.OrderedDict
...@@ -16,9 +16,8 @@ from theano.scalar import get_scalar_type ...@@ -16,9 +16,8 @@ from theano.scalar import get_scalar_type
from theano.printing import pprint from theano.printing import pprint
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.gof.null_type import NullType from theano.gof.null_type import NullType
from theano.gof.utils import hash_from_dict
from theano.tensor import elemwise_cgen as cgen from theano.tensor import elemwise_cgen as cgen
from theano.misc.frozendict import frozendict
config = theano.config config = theano.config
...@@ -472,14 +471,16 @@ second dimension ...@@ -472,14 +471,16 @@ second dimension
""" """
__props__ = ("scalar_op", "inplace_pattern", "name", "nfunc_spec", "openmp")
def __init__(self, scalar_op, inplace_pattern=None, name=None, def __init__(self, scalar_op, inplace_pattern=None, name=None,
nfunc_spec=None, openmp=None): nfunc_spec=None, openmp=None):
if inplace_pattern is None: if inplace_pattern is None:
inplace_pattern = {} inplace_pattern = frozendict({})
self.name = name self.name = name
self.scalar_op = scalar_op self.scalar_op = scalar_op
self.inplace_pattern = inplace_pattern self.inplace_pattern = frozendict(inplace_pattern)
self.destroy_map = dict((o, [i]) for o, i in inplace_pattern.items()) self.destroy_map = dict((o, [i]) for o, i in frozendict(inplace_pattern).items())
self.ufunc = None self.ufunc = None
self.nfunc = None self.nfunc = None
...@@ -489,8 +490,6 @@ second dimension ...@@ -489,8 +490,6 @@ second dimension
if nfunc_spec: if nfunc_spec:
self.nfunc = getattr(numpy, nfunc_spec[0]) self.nfunc = getattr(numpy, nfunc_spec[0])
# precompute the hash of this node
self._rehash()
super(Elemwise, self).__init__(openmp=openmp) super(Elemwise, self).__init__(openmp=openmp)
def __getstate__(self): def __getstate__(self):
...@@ -498,7 +497,6 @@ second dimension ...@@ -498,7 +497,6 @@ second dimension
d.pop('ufunc') d.pop('ufunc')
d.pop('nfunc') d.pop('nfunc')
d.pop('__epydoc_asRoutine', None) d.pop('__epydoc_asRoutine', None)
d.pop('_hashval')
return d return d
def __setstate__(self, d): def __setstate__(self, d):
...@@ -511,7 +509,6 @@ second dimension ...@@ -511,7 +509,6 @@ second dimension
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.ufunc = numpy.frompyfunc(self.scalar_op.impl,
self.scalar_op.nin, self.scalar_op.nin,
self.scalar_op.nout) self.scalar_op.nout)
self._rehash()
def get_output_info(self, dim_shuffle, *inputs): def get_output_info(self, dim_shuffle, *inputs):
"""Return the outputs dtype and broadcastable pattern and the """Return the outputs dtype and broadcastable pattern and the
...@@ -584,37 +581,6 @@ second dimension ...@@ -584,37 +581,6 @@ second dimension
out_broadcastables)] out_broadcastables)]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def __eq__(self, other):
if type(self) == type(other):
items = list(self.inplace_pattern.items())
other_items = list(other.inplace_pattern.items())
items.sort()
other_items.sort()
rval = ((self.scalar_op == other.scalar_op) and
(items == other_items))
return rval
return False
def _rehash(self):
inplace_pattern_hash = hash_from_dict(self.inplace_pattern)
h = hash('Elemwise') ^ hash(self.scalar_op) ^ inplace_pattern_hash
assert h == getattr(self, '_hashval', h)
self._hashval = h
def __hash__(self):
return self._hashval
def __str__(self):
if self.name is None:
if self.inplace_pattern:
items = list(self.inplace_pattern.items())
items.sort()
return "Elemwise{%s}%s" % (self.scalar_op, str(items))
else:
return "Elemwise{%s}" % (self.scalar_op)
else:
return self.name
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
outs = self(*inputs, **dict(return_list=True)) outs = self(*inputs, **dict(return_list=True))
rval = [None for x in outs] rval = [None for x in outs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论