提交 8779c7eb authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed scan hashing issue

上级 18b9b449
...@@ -13,29 +13,24 @@ def info(*msg): ...@@ -13,29 +13,24 @@ def info(*msg):
_logger.info('INFO theano.scan: '+' '.join(msg)) _logger.info('INFO theano.scan: '+' '.join(msg))
# Hashing a list; list used by scan are list of numbers, therefore a list # Hashing a dictionary or a list or a tuple or any type that is hashable with
# can be hashed by hashing all elements in the list # the hash() function
def hash_list(list): def hash_listsDictsTuples(x):
hash_value = 0 hash_value = 0
for v in list: if type(x) == dict :
hash_value ^= hash(v) for k,v in x.iteritems():
return hash_value hash_value ^= hash_listsDictsTuples(k)
hash_value ^= hash_listsDictsTuples(v)
elif type(x) in (list,tuple):
# Hashing a dictionary; the dictionary used by scan has as keys numbers and for v in x:
# as values either numbers or list of numbers hash_value ^= hash_listsDictsTuples(v)
def hash_dict(dictionary): else:
hash_value = 0 try:
for k,v in dictionary.iteritems(): hash_value ^= hash(x)
# hash key except:
hash_value ^= hash(k) pass
if type(v) in (list,tuple):
hash_value ^= hash_list(v)
else:
hash_value ^= hash(v)
return hash_value return hash_value
def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
sequences_taps={}, outputs_taps = {}, sequences_taps={}, outputs_taps = {},
n_steps = theano.tensor.zero(), force_gradient = False, n_steps = theano.tensor.zero(), force_gradient = False,
...@@ -268,13 +263,13 @@ class Scan(theano.Op): ...@@ -268,13 +263,13 @@ class Scan(theano.Op):
hash(self.go_backwards) ^\ hash(self.go_backwards) ^\
hash(self.truncate_gradient) ^\ hash(self.truncate_gradient) ^\
hash(self.n_args) ^ \ hash(self.n_args) ^ \
hash_list(self.outputs) ^ \ hash_listsDictsTuples(self.outputs) ^ \
hash_list(self.inputs) ^ \ hash_listsDictsTuples(self.inputs) ^ \
hash_list(self.g_ins) ^ \ hash_listsDictsTuples(self.g_ins) ^ \
hash_list(self.g_outs) ^ \ hash_listsDictsTuples(self.g_outs) ^ \
hash_dict(self.seqs_taps) ^\ hash_listsDictsTuples(self.seqs_taps) ^\
hash_dict(self.outs_taps) ^\ hash_listsDictsTuples(self.outs_taps) ^\
hash_dict(self.updates) hash_listsDictsTuples(self.updates)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论