提交 8779c7eb authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed scan hashing issue

上级 18b9b449
......@@ -13,29 +13,24 @@ def info(*msg):
_logger.info('INFO theano.scan: '+' '.join(msg))
# Hashing a list; list used by scan are list of numbers, therefore a list
# can be hashed by hashing all elements in the list
def hash_list(list):
# Hashing a dictionary or a list or a tuple or any type that is hashable with
# the hash() function
def hash_listsDictsTuples(x):
hash_value = 0
for v in list:
hash_value ^= hash(v)
return hash_value
# Hashing a dictionary; the dictionary used by scan has as keys numbers and
# as values either numbers or list of numbers
def hash_dict(dictionary):
hash_value = 0
for k,v in dictionary.iteritems():
# hash key
hash_value ^= hash(k)
if type(v) in (list,tuple):
hash_value ^= hash_list(v)
else:
hash_value ^= hash(v)
if type(x) == dict :
for k,v in x.iteritems():
hash_value ^= hash_listsDictsTuples(k)
hash_value ^= hash_listsDictsTuples(v)
elif type(x) in (list,tuple):
for v in x:
hash_value ^= hash_listsDictsTuples(v)
else:
try:
hash_value ^= hash(x)
except:
pass
return hash_value
def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
sequences_taps={}, outputs_taps = {},
n_steps = theano.tensor.zero(), force_gradient = False,
......@@ -268,13 +263,13 @@ class Scan(theano.Op):
hash(self.go_backwards) ^\
hash(self.truncate_gradient) ^\
hash(self.n_args) ^ \
hash_list(self.outputs) ^ \
hash_list(self.inputs) ^ \
hash_list(self.g_ins) ^ \
hash_list(self.g_outs) ^ \
hash_dict(self.seqs_taps) ^\
hash_dict(self.outs_taps) ^\
hash_dict(self.updates)
hash_listsDictsTuples(self.outputs) ^ \
hash_listsDictsTuples(self.inputs) ^ \
hash_listsDictsTuples(self.g_ins) ^ \
hash_listsDictsTuples(self.g_outs) ^ \
hash_listsDictsTuples(self.seqs_taps) ^\
hash_listsDictsTuples(self.outs_taps) ^\
hash_listsDictsTuples(self.updates)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论