提交 ea99dcb4 authored 作者: Frederic's avatar Frederic

make an utility function to hash dict/list for inplace pattern.

上级 3d69604e
......@@ -13,6 +13,8 @@ from theano import scalar
from theano.scalar import Scalar
from theano.printing import min_informative_str, pprint
from theano.gof.python25 import all, any
from theano.tensor.utils import hash_from_dict
config = theano.config
......@@ -563,17 +565,8 @@ class Elemwise(Op):
return False
def _rehash(self):
items = self.inplace_pattern.items()
items.sort()
first_part = [k for k, v in items]
second_part = []
for k, v in items:
if isinstance(v, (tuple, list)):
second_part += [tuple(v)]
else:
second_part += [v]
tuple_items = tuple(first_part + second_part)
h = hash('Elemwise') ^ hash(self.scalar_op) ^ hash(tuple_items)
inplace_pattern_hash = hash_from_dict(self.inplace_pattern)
h = hash('Elemwise') ^ hash(self.scalar_op) ^ inplace_pattern_hash
assert h == getattr(self, '_hashval', h)
self._hashval = h
......
import numpy
from theano.tensor.utils import hash_from_ndarray
from theano.tensor.utils import hash_from_ndarray, hash_from_dict
def test_hash_from_ndarray():
......@@ -31,3 +31,18 @@ def test_hash_from_ndarray():
assert hash_from_ndarray(rng[:4]) == hash_from_ndarray(rng[:4].copy())
assert hash_from_ndarray(rng[::2]) == hash_from_ndarray(rng[::2].copy())
assert hash_from_ndarray(rng[::-1]) == hash_from_ndarray(rng[::-1].copy())
def test_hash_from_dict():
dicts = [{}, {0: 0}, {0: 1}, {1: 0}, {1: 1},
{0: (0,)}, {0: [1]},
{0: (0, 1)}, {0: [1, 0]},
]
hashs = []
for idx, d in enumerate(dicts):
h = hash_from_dict(d)
assert h not in hashs
hashs.append(h)
# List are not hashable. So they are transformed into tuple.
assert hash_from_dict({0: (0,)}) == hash_from_dict({0: [0]})
......@@ -18,3 +18,28 @@ def hash_from_ndarray(data):
hash_from_code(str(data.shape)) +
hash_from_code(str(data.strides)) +
hash_from_code(str(data.dtype)))
def hash_from_dict(d):
"""Work around the fact that dict are not hashable in python
This request that all object have a sorted order that depend only
on the value of the object. This is true for integer/float/string
We do not verify that the objects in the dict what this properties
Also, we transform values that are list into tuple as list are not
hashable.
"""
items = d.items()
items.sort()
first_part = [k for k, v in items]
second_part = []
for k, v in items:
if isinstance(v, (tuple, list)):
second_part += [tuple(v)]
else:
second_part += [v]
tuple_items = tuple(first_part + second_part)
return hash(tuple_items)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论