提交 997abb48 authored 作者: nouiz's avatar nouiz

Merge pull request #1100 from delallea/determinism

Fix for a source of non-determinism
......@@ -167,10 +167,11 @@ def scan(fn,
# ^ initial state but taps not provided
if 'taps' in outs_info[i]:
# ^ explicitly provided a None for taps
_logger.warning('Output %s ( index %d) has a memory '
'buffer but taps is explicitly set to None ',
getattr(outs_info[i]['membuf'], 'name', 'None'),
i)
_logger.warning(
'Output %s (index %d) has a memory '
'buffer but taps is explicitly set to None ',
getattr(outs_info[i]['membuf'], 'name', 'None'),
i)
outs_info[i]['taps'] = [-1]
else:
# if a None is provided as the output info we replace it
......@@ -213,7 +214,7 @@ def scan(fn,
nw_slice = _seq_val_slice.type()
if seq.name:
nw_slice.name=seq.name + '[t]'
nw_slice.name = seq.name + '[t]'
scan_seqs.append(_seq_val)
inner_seqs.append(nw_slice)
inner_slices.append(actual_slice)
......@@ -354,7 +355,6 @@ def scan(fn,
else:
pass
# Re-order args
max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1
max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1
......@@ -401,7 +401,8 @@ def scan(fn,
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
lambda_result = fn(*args)
condition, outputs, updates = scan_utils.get_updates_and_outputs(lambda_result)
condition, outputs, updates = scan_utils.get_updates_and_outputs(
lambda_result)
if condition is not None:
as_while = True
else:
......@@ -685,5 +686,5 @@ def scan(fn,
elif len(scan_out_list) == 0:
scan_out_list = None
assert isinstance(update_map, dict) and 'Ordered' in str(type(update_map))
assert isinstance(update_map, OrderedDict)
return (scan_out_list, update_map)
......@@ -16,7 +16,8 @@ logger = logging.getLogger('theano.updates')
import warnings
# Must be an OrderedDict or updates will be applied in a non-deterministic order
# Must be an OrderedDict or updates will be applied in a non-deterministic
# order.
class OrderedUpdates(OrderedDict):
"""
Dict-like mapping from SharedVariable keys to their new values.
......@@ -24,13 +25,20 @@ class OrderedUpdates(OrderedDict):
This mapping supports the use of the "+" operator for the union of updates.
"""
def __init__(self, *key, **kwargs):
ret = super(OrderedUpdates, self).__init__(*key, **kwargs)
if (len(key) >= 1 and
isinstance(key[0], dict) and
len(key[0]) > 1 and
not isinstance(key[0], OrderedDict)):
# Warn when using as input a non-ordered dictionary.
warnings.warn('Initializing an `OrderedUpdates` from a '
'non-ordered dictionary with 2+ elements could '
'make your code non-deterministic')
super(OrderedUpdates, self).__init__(*key, **kwargs)
for key in self:
if not isinstance(key, SharedVariable):
raise TypeError(
'OrderedUpdates keys must inherit from SharedVariable',
key)
return ret
def __setitem__(self, key, value):
if isinstance(key, SharedVariable):
......@@ -44,13 +52,20 @@ class OrderedUpdates(OrderedDict):
return super(OrderedUpdates, self).__setitem__(key, value)
else:
raise TypeError('OrderedUpdates keys must inherit from SharedVariable',
key)
raise TypeError('OrderedUpdates keys must inherit from '
'SharedVariable', key)
def update(self, other=None):
if other is None:
return
for key, val in dict(other).iteritems():
if (isinstance(other, dict) and
len(other) > 1 and
not isinstance(other, OrderedDict)):
# Warn about non-determinism.
warnings.warn('Updating an `OrderedUpdates` with a '
'non-ordered dictionary with 2+ elements could '
'make your code non-deterministic')
for key, val in OrderedDict(other).iteritems():
if key in self:
if self[key] == val:
continue
......@@ -69,6 +84,7 @@ class OrderedUpdates(OrderedDict):
rval.update(self)
return rval
def Updates(*key, **kwargs):
warnings.warn("Updates is deprecated. Switch to OrderedUpdates.")
return OrderedUpdates(*key, **kwargs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论