提交 997abb48 authored 作者: nouiz's avatar nouiz

Merge pull request #1100 from delallea/determinism

Fix for a source of non-determinism
...@@ -167,7 +167,8 @@ def scan(fn, ...@@ -167,7 +167,8 @@ def scan(fn,
# ^ initial state but taps not provided # ^ initial state but taps not provided
if 'taps' in outs_info[i]: if 'taps' in outs_info[i]:
# ^ explicitly provided a None for taps # ^ explicitly provided a None for taps
_logger.warning('Output %s ( index %d) has a memory ' _logger.warning(
'Output %s (index %d) has a memory '
'buffer but taps is explicitly set to None ', 'buffer but taps is explicitly set to None ',
getattr(outs_info[i]['membuf'], 'name', 'None'), getattr(outs_info[i]['membuf'], 'name', 'None'),
i) i)
...@@ -213,7 +214,7 @@ def scan(fn, ...@@ -213,7 +214,7 @@ def scan(fn,
nw_slice = _seq_val_slice.type() nw_slice = _seq_val_slice.type()
if seq.name: if seq.name:
nw_slice.name=seq.name + '[t]' nw_slice.name = seq.name + '[t]'
scan_seqs.append(_seq_val) scan_seqs.append(_seq_val)
inner_seqs.append(nw_slice) inner_seqs.append(nw_slice)
inner_slices.append(actual_slice) inner_slices.append(actual_slice)
...@@ -354,7 +355,6 @@ def scan(fn, ...@@ -354,7 +355,6 @@ def scan(fn,
else: else:
pass pass
# Re-order args # Re-order args
max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1 max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1
max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1 max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1
...@@ -401,7 +401,8 @@ def scan(fn, ...@@ -401,7 +401,8 @@ def scan(fn,
# when we apply the lambda expression we get a mixture of update rules # when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated # and outputs that needs to be separated
lambda_result = fn(*args) lambda_result = fn(*args)
condition, outputs, updates = scan_utils.get_updates_and_outputs(lambda_result) condition, outputs, updates = scan_utils.get_updates_and_outputs(
lambda_result)
if condition is not None: if condition is not None:
as_while = True as_while = True
else: else:
...@@ -685,5 +686,5 @@ def scan(fn, ...@@ -685,5 +686,5 @@ def scan(fn,
elif len(scan_out_list) == 0: elif len(scan_out_list) == 0:
scan_out_list = None scan_out_list = None
assert isinstance(update_map, dict) and 'Ordered' in str(type(update_map)) assert isinstance(update_map, OrderedDict)
return (scan_out_list, update_map) return (scan_out_list, update_map)
...@@ -16,7 +16,8 @@ logger = logging.getLogger('theano.updates') ...@@ -16,7 +16,8 @@ logger = logging.getLogger('theano.updates')
import warnings import warnings
# Must be an OrderedDict or updates will be applied in a non-deterministic order # Must be an OrderedDict or updates will be applied in a non-deterministic
# order.
class OrderedUpdates(OrderedDict): class OrderedUpdates(OrderedDict):
""" """
Dict-like mapping from SharedVariable keys to their new values. Dict-like mapping from SharedVariable keys to their new values.
...@@ -24,13 +25,20 @@ class OrderedUpdates(OrderedDict): ...@@ -24,13 +25,20 @@ class OrderedUpdates(OrderedDict):
This mapping supports the use of the "+" operator for the union of updates. This mapping supports the use of the "+" operator for the union of updates.
""" """
def __init__(self, *key, **kwargs): def __init__(self, *key, **kwargs):
ret = super(OrderedUpdates, self).__init__(*key, **kwargs) if (len(key) >= 1 and
isinstance(key[0], dict) and
len(key[0]) > 1 and
not isinstance(key[0], OrderedDict)):
# Warn when using as input a non-ordered dictionary.
warnings.warn('Initializing an `OrderedUpdates` from a '
'non-ordered dictionary with 2+ elements could '
'make your code non-deterministic')
super(OrderedUpdates, self).__init__(*key, **kwargs)
for key in self: for key in self:
if not isinstance(key, SharedVariable): if not isinstance(key, SharedVariable):
raise TypeError( raise TypeError(
'OrderedUpdates keys must inherit from SharedVariable', 'OrderedUpdates keys must inherit from SharedVariable',
key) key)
return ret
def __setitem__(self, key, value): def __setitem__(self, key, value):
if isinstance(key, SharedVariable): if isinstance(key, SharedVariable):
...@@ -44,13 +52,20 @@ class OrderedUpdates(OrderedDict): ...@@ -44,13 +52,20 @@ class OrderedUpdates(OrderedDict):
return super(OrderedUpdates, self).__setitem__(key, value) return super(OrderedUpdates, self).__setitem__(key, value)
else: else:
raise TypeError('OrderedUpdates keys must inherit from SharedVariable', raise TypeError('OrderedUpdates keys must inherit from '
key) 'SharedVariable', key)
def update(self, other=None): def update(self, other=None):
if other is None: if other is None:
return return
for key, val in dict(other).iteritems(): if (isinstance(other, dict) and
len(other) > 1 and
not isinstance(other, OrderedDict)):
# Warn about non-determinism.
warnings.warn('Updating an `OrderedUpdates` with a '
'non-ordered dictionary with 2+ elements could '
'make your code non-deterministic')
for key, val in OrderedDict(other).iteritems():
if key in self: if key in self:
if self[key] == val: if self[key] == val:
continue continue
...@@ -69,6 +84,7 @@ class OrderedUpdates(OrderedDict): ...@@ -69,6 +84,7 @@ class OrderedUpdates(OrderedDict):
rval.update(self) rval.update(self)
return rval return rval
def Updates(*key, **kwargs): def Updates(*key, **kwargs):
warnings.warn("Updates is deprecated. Switch to OrderedUpdates.") warnings.warn("Updates is deprecated. Switch to OrderedUpdates.")
return OrderedUpdates(*key, **kwargs) return OrderedUpdates(*key, **kwargs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论