提交 6447ce7e 作者: Ian Goodfellow

fixed determinism issues in scan

上级 d6d377d6
...@@ -88,7 +88,7 @@ from printing import \ ...@@ -88,7 +88,7 @@ from printing import \
import scan_module import scan_module
from scan_module import scan, map, reduce, foldl, foldr, clone from scan_module import scan, map, reduce, foldl, foldr, clone
from updates import Updates from updates import OrderedUpdates
import tensor import tensor
import scalar import scalar
......
...@@ -17,10 +17,11 @@ import numpy ...@@ -17,10 +17,11 @@ import numpy
from theano.compile import SharedVariable, function from theano.compile import SharedVariable, function
from theano import compile from theano import compile
from theano import gof from theano import gof
from theano.gof.python25 import OrderedDict
from theano.tensor import opt from theano.tensor import opt
from theano import tensor from theano import tensor
from theano import config from theano import config
from theano.updates import Updates from theano.updates import OrderedUpdates
from theano.scan_module import scan_op from theano.scan_module import scan_op
...@@ -147,7 +148,7 @@ def scan(fn, ...@@ -147,7 +148,7 @@ def scan(fn,
n_seqs = len(seqs) n_seqs = len(seqs)
n_outs = len(outs_info) n_outs = len(outs_info)
return_steps = {} return_steps = OrderedDict()
# wrap outputs info in a dictionary if they are not already in one # wrap outputs info in a dictionary if they are not already in one
for i in xrange(n_outs): for i in xrange(n_outs):
if outs_info[i] is not None: if outs_info[i] is not None:
...@@ -242,7 +243,7 @@ def scan(fn, ...@@ -242,7 +243,7 @@ def scan(fn,
mit_sot_inner_inputs = [] mit_sot_inner_inputs = []
mit_sot_inner_slices = [] mit_sot_inner_slices = []
mit_sot_inner_outputs = [] mit_sot_inner_outputs = []
mit_sot_return_steps = {} mit_sot_return_steps = OrderedDict()
mit_sot_tap_array = [] mit_sot_tap_array = []
mit_sot_rightOrder = [] mit_sot_rightOrder = []
...@@ -251,7 +252,7 @@ def scan(fn, ...@@ -251,7 +252,7 @@ def scan(fn,
sit_sot_inner_inputs = [] sit_sot_inner_inputs = []
sit_sot_inner_slices = [] sit_sot_inner_slices = []
sit_sot_inner_outputs = [] sit_sot_inner_outputs = []
sit_sot_return_steps = {} sit_sot_return_steps = OrderedDict()
sit_sot_rightOrder = [] sit_sot_rightOrder = []
nit_sot_steps = [] nit_sot_steps = []
# go through outputs picking up time slices as needed # go through outputs picking up time slices as needed
...@@ -398,7 +399,8 @@ def scan(fn, ...@@ -398,7 +399,8 @@ def scan(fn,
not isinstance(arg, tensor.Constant))] not isinstance(arg, tensor.Constant))]
# when we apply the lambda expression we get a mixture of update rules # when we apply the lambda expression we get a mixture of update rules
# and outputs that need to be separated # and outputs that need to be separated
condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args)) lambda_result = fn(*args)
condition, outputs, updates = scan_utils.get_updates_and_outputs(lambda_result)
if condition is not None: if condition is not None:
as_while = True as_while = True
else: else:
...@@ -464,6 +466,11 @@ def scan(fn, ...@@ -464,6 +466,11 @@ def scan(fn,
dummy_outs = outputs dummy_outs = outputs
if condition is not None: if condition is not None:
dummy_outs.append(condition) dummy_outs.append(condition)
# If we use a regular dict here, the results are non-deterministic
assert isinstance(updates, (list, tuple)) or (isinstance(updates, dict) and \
'Ordered' in str(type(updates)))
dummy_f = function(dummy_args, dummy_f = function(dummy_args,
dummy_outs, dummy_outs,
updates=updates, updates=updates,
...@@ -508,7 +515,7 @@ def scan(fn, ...@@ -508,7 +515,7 @@ def scan(fn,
sit_sot_inner_outputs.append(outputs[i]) sit_sot_inner_outputs.append(outputs[i])
## Step 5.3 Outputs that correspond to update rules of shared variables ## Step 5.3 Outputs that correspond to update rules of shared variables
givens = {} givens = OrderedDict()
n_shared_outs = 0 n_shared_outs = 0
shared_scan_inputs = [] shared_scan_inputs = []
shared_inner_inputs = [] shared_inner_inputs = []
...@@ -527,7 +534,7 @@ def scan(fn, ...@@ -527,7 +534,7 @@ def scan(fn,
## Step 5.4 Outputs with no taps used in the input ## Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0 n_nit_sot = 0
nit_sot_inner_outputs = [] nit_sot_inner_outputs = []
nit_sot_return_steps = {} nit_sot_return_steps = OrderedDict()
nit_sot_rightOrder = [] nit_sot_rightOrder = []
for i, out in enumerate(outs_info): for i, out in enumerate(outs_info):
if not 'taps' in out: if not 'taps' in out:
...@@ -582,7 +589,7 @@ def scan(fn, ...@@ -582,7 +589,7 @@ def scan(fn,
shared_inner_outputs) shared_inner_outputs)
if condition is not None: if condition is not None:
inner_outs.append(condition) inner_outs.append(condition)
new_givens = {} new_givens = OrderedDict()
for w, w_copy in givens.iteritems(): for w, w_copy in givens.iteritems():
new_givens[w] = w.type.filter_variable(w_copy) new_givens[w] = w.type.filter_variable(w_copy)
...@@ -593,7 +600,7 @@ def scan(fn, ...@@ -593,7 +600,7 @@ def scan(fn,
## ##
tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)] tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
info = {} info = OrderedDict()
info['tap_array'] = tap_array info['tap_array'] = tap_array
info['n_seqs'] = n_seqs info['n_seqs'] = n_seqs
...@@ -607,7 +614,7 @@ def scan(fn, ...@@ -607,7 +614,7 @@ def scan(fn,
info['truncate_gradient'] = -1 info['truncate_gradient'] = -1
info['name'] = name info['name'] = name
info['mode'] = mode info['mode'] = mode
info['destroy_map'] = {} info['destroy_map'] = OrderedDict()
info['inplace'] = False info['inplace'] = False
info['gpu'] = False info['gpu'] = False
info['as_while'] = as_while info['as_while'] = as_while
...@@ -641,7 +648,7 @@ def scan(fn, ...@@ -641,7 +648,7 @@ def scan(fn,
### and so on ... ### and so on ...
## ##
update_map = Updates() update_map = OrderedUpdates()
offset = n_mit_mot offset = n_mit_mot
offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array] offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array]
...@@ -675,4 +682,5 @@ def scan(fn, ...@@ -675,4 +682,5 @@ def scan(fn,
elif len(scan_out_list) == 0: elif len(scan_out_list) == 0:
scan_out_list = None scan_out_list = None
assert isinstance(update_map, dict) and 'Ordered' in str(type(update_map))
return (scan_out_list, update_map) return (scan_out_list, update_map)
...@@ -52,8 +52,9 @@ from theano import gof ...@@ -52,8 +52,9 @@ from theano import gof
from theano.tensor import opt from theano.tensor import opt
from theano import tensor from theano import tensor
from theano import config from theano import config
from theano.updates import Updates from theano.updates import OrderedUpdates
from theano.compile import ops from theano.compile import ops
from theano.gof.python25 import OrderedDict
import scan_op import scan_op
...@@ -376,11 +377,11 @@ def scan(fn, ...@@ -376,11 +377,11 @@ def scan(fn,
n_seqs = len(seqs) n_seqs = len(seqs)
n_outs = len(outs_info) n_outs = len(outs_info)
return_steps = {} return_steps = OrderedDict()
# wrap sequences in a dictionary if they are not already dictionaries # wrap sequences in a dictionary if they are not already dictionaries
for i in xrange(n_seqs): for i in xrange(n_seqs):
if not isinstance(seqs[i], dict): if not isinstance(seqs[i], dict):
seqs[i] = dict(input=seqs[i], taps=[0]) seqs[i] = OrderedDict([('input', seqs[i]), ('taps', [0])])
elif seqs[i].get('taps', None): elif seqs[i].get('taps', None):
seqs[i]['taps'] = wrap_into_list(seqs[i]['taps']) seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
elif seqs[i].get('taps', True) is None: elif seqs[i].get('taps', True) is None:
...@@ -402,7 +403,7 @@ def scan(fn, ...@@ -402,7 +403,7 @@ def scan(fn,
if not isinstance(outs_info[i], dict): if not isinstance(outs_info[i], dict):
# by default any output has a tap value of -1 # by default any output has a tap value of -1
outs_info[i] = dict(initial=outs_info[i], taps=[-1]) outs_info[i] = OrderedDict([('initial', outs_info[i]), ('taps', [-1])])
elif (not outs_info[i].get('initial', None) and elif (not outs_info[i].get('initial', None) and
outs_info[i].get('taps', None)): outs_info[i].get('taps', None)):
# ^ no initial state but taps provided # ^ no initial state but taps provided
...@@ -421,8 +422,8 @@ def scan(fn, ...@@ -421,8 +422,8 @@ def scan(fn,
outs_info[i]['taps'] = [-1] outs_info[i]['taps'] = [-1]
else: else:
# if a None is provided as the output info we replace it # if a None is provided as the output info we replace it
# with an empty dict() to simplify handling # with an empty OrderedDict() to simplify handling
outs_info[i] = dict() outs_info[i] = OrderedDict()
## ##
### Step 2. Generate inputs and outputs of the inner functions ### Step 2. Generate inputs and outputs of the inner functions
...@@ -565,7 +566,7 @@ def scan(fn, ...@@ -565,7 +566,7 @@ def scan(fn,
mit_sot_inner_inputs = [] mit_sot_inner_inputs = []
mit_sot_inner_slices = [] mit_sot_inner_slices = []
mit_sot_inner_outputs = [] mit_sot_inner_outputs = []
mit_sot_return_steps = {} mit_sot_return_steps = OrderedDict()
mit_sot_tap_array = [] mit_sot_tap_array = []
mit_sot_rightOrder = [] mit_sot_rightOrder = []
...@@ -574,7 +575,7 @@ def scan(fn, ...@@ -574,7 +575,7 @@ def scan(fn,
sit_sot_inner_inputs = [] sit_sot_inner_inputs = []
sit_sot_inner_slices = [] sit_sot_inner_slices = []
sit_sot_inner_outputs = [] sit_sot_inner_outputs = []
sit_sot_return_steps = {} sit_sot_return_steps = OrderedDict()
sit_sot_rightOrder = [] sit_sot_rightOrder = []
# go through outputs picking up time slices as needed # go through outputs picking up time slices as needed
...@@ -777,7 +778,7 @@ def scan(fn, ...@@ -777,7 +778,7 @@ def scan(fn,
# as non sequences at the end of our args # as non sequences at the end of our args
fake_nonseqs = [x.type() for x in non_seqs] fake_nonseqs = [x.type() for x in non_seqs]
fake_outputs = scan_utils.clone(outputs, fake_outputs = scan_utils.clone(outputs,
replace=dict(zip(non_seqs, replace=OrderedDict(zip(non_seqs,
fake_nonseqs))) fake_nonseqs)))
all_inputs = itertools.ifilter( all_inputs = itertools.ifilter(
lambda x: (isinstance(x, gof.Variable) and lambda x: (isinstance(x, gof.Variable) and
...@@ -825,7 +826,7 @@ def scan(fn, ...@@ -825,7 +826,7 @@ def scan(fn,
n_outs = len(dummy_f.maker.outputs) n_outs = len(dummy_f.maker.outputs)
if as_while: if as_while:
n_outs = n_outs - 1 n_outs = n_outs - 1
outs_info = [dict() for x in xrange(n_outs)] outs_info = [OrderedDict() for x in xrange(n_outs)]
## Step 5.1 Outputs with taps different than -1 ## Step 5.1 Outputs with taps different than -1
...@@ -839,7 +840,7 @@ def scan(fn, ...@@ -839,7 +840,7 @@ def scan(fn,
sit_sot_inner_outputs.append(outputs[i]) sit_sot_inner_outputs.append(outputs[i])
## Step 5.3 Outputs that correspond to update rules of shared variables ## Step 5.3 Outputs that correspond to update rules of shared variables
givens = {} givens = OrderedDict()
n_shared_outs = 0 n_shared_outs = 0
shared_scan_inputs = [] shared_scan_inputs = []
shared_inner_inputs = [] shared_inner_inputs = []
...@@ -879,7 +880,7 @@ def scan(fn, ...@@ -879,7 +880,7 @@ def scan(fn,
## Step 5.4 Outputs with no taps used in the input ## Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0 n_nit_sot = 0
nit_sot_inner_outputs = [] nit_sot_inner_outputs = []
nit_sot_return_steps = {} nit_sot_return_steps = OrderedDict()
nit_sot_rightOrder = [] nit_sot_rightOrder = []
for i, out in enumerate(outs_info): for i, out in enumerate(outs_info):
if not 'taps' in out: if not 'taps' in out:
...@@ -902,7 +903,7 @@ def scan(fn, ...@@ -902,7 +903,7 @@ def scan(fn,
if (not isinstance(arg, SharedVariable) and if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant))] not isinstance(arg, tensor.Constant))]
givens.update(dict(zip(other_scan_args, other_inner_args))) givens.update(OrderedDict(zip(other_scan_args, other_inner_args)))
other_shared_scan_args = [arg.variable for arg other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and if (isinstance(arg.variable, SharedVariable) and
...@@ -911,7 +912,7 @@ def scan(fn, ...@@ -911,7 +912,7 @@ def scan(fn,
in dummy_f.maker.expanded_inputs in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and if (isinstance(arg.variable, SharedVariable) and
not arg.update)] not arg.update)]
givens.update(dict(zip(other_shared_scan_args, givens.update(OrdereDict(zip(other_shared_scan_args,
other_shared_inner_args))) other_shared_inner_args)))
## ##
...@@ -943,7 +944,7 @@ def scan(fn, ...@@ -943,7 +944,7 @@ def scan(fn,
# replace w with w_copy, where w is CudaNdarray # replace w with w_copy, where w is CudaNdarray
# and w_copy is TensorType. This is caused because shared # and w_copy is TensorType. This is caused because shared
# variables are put on GPU right away >:| , # variables are put on GPU right away >:| ,
new_givens = {} new_givens = OrderedDict()
for w, w_copy in givens.iteritems(): for w, w_copy in givens.iteritems():
if (isinstance(w.type, cuda.CudaNdarrayType) if (isinstance(w.type, cuda.CudaNdarrayType)
...@@ -962,7 +963,7 @@ def scan(fn, ...@@ -962,7 +963,7 @@ def scan(fn,
## ##
tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)] tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
info = {} info = OrderedDict()
info['tap_array'] = tap_array info['tap_array'] = tap_array
info['n_seqs'] = n_seqs info['n_seqs'] = n_seqs
...@@ -976,7 +977,7 @@ def scan(fn, ...@@ -976,7 +977,7 @@ def scan(fn,
info['truncate_gradient'] = truncate_gradient info['truncate_gradient'] = truncate_gradient
info['name'] = name info['name'] = name
info['mode'] = mode info['mode'] = mode
info['destroy_map'] = {} info['destroy_map'] = OrderedDict()
info['gpu'] = False info['gpu'] = False
info['as_while'] = as_while info['as_while'] = as_while
info['profile'] = profile info['profile'] = profile
...@@ -1012,7 +1013,7 @@ def scan(fn, ...@@ -1012,7 +1013,7 @@ def scan(fn,
### and so on ... ### and so on ...
## ##
update_map = Updates() update_map = OrderedUpdates()
def remove_dimensions(outs, steps_return, offsets=None): def remove_dimensions(outs, steps_return, offsets=None):
out_ls = [] out_ls = []
......
...@@ -23,7 +23,7 @@ import theano ...@@ -23,7 +23,7 @@ import theano
from theano.compile.pfunc import rebuild_collect_shared from theano.compile.pfunc import rebuild_collect_shared
from theano import gof from theano import gof
from theano import tensor, scalar from theano import tensor, scalar
from theano.gof.python25 import all from theano.gof.python25 import all, OrderedDict
from theano.tensor.basic import get_constant_value from theano.tensor.basic import get_constant_value
...@@ -181,12 +181,17 @@ def clone(output, ...@@ -181,12 +181,17 @@ def clone(output,
def get_updates_and_outputs(ls): def get_updates_and_outputs(ls):
""" """
This function tries to recognize the updates dictionary, the This function tries to recognize the updates OrderedDict, the
list of outputs and the stopping condition returned by the list of outputs and the stopping condition returned by the
lambda expression and arrange them in a predefined order lambda expression and arrange them in a predefined order
WRITEME: what is the type of ls? how is it formatted?
if it's not in the predefined order already, how does
this function know how to put it in that order?
""" """
def is_outputs(elem): def is_outputs(elem):
if (isinstance(elem, (list, tuple)) and if (isinstance(elem, (list, tuple)) and
all([isinstance(x, theano.Variable) for x in elem])): all([isinstance(x, theano.Variable) for x in elem])):
...@@ -197,6 +202,8 @@ def get_updates_and_outputs(ls): ...@@ -197,6 +202,8 @@ def get_updates_and_outputs(ls):
def is_updates(elem): def is_updates(elem):
if isinstance(elem, dict): if isinstance(elem, dict):
# Make sure the updates will be applied in a deterministic order
assert 'Ordered' in str(type(elem))
return True return True
# Dictionaries can be given as lists of tuples # Dictionaries can be given as lists of tuples
if (isinstance(elem, (list, tuple)) and if (isinstance(elem, (list, tuple)) and
...@@ -240,12 +247,13 @@ def get_updates_and_outputs(ls): ...@@ -240,12 +247,13 @@ def get_updates_and_outputs(ls):
'variables (or `theano.scan_module.until` objects for ' 'variables (or `theano.scan_module.until` objects for '
'conditions). In particular if you need to use constant ' 'conditions). In particular if you need to use constant '
'values, you can use `tensor.constant` to turn them into ' 'values, you can use `tensor.constant` to turn them into '
'Theano variables.') 'Theano variables.')
if is_outputs(ls): if is_outputs(ls):
return None, _list(ls), {} return None, _list(ls), OrderedDict()
if is_updates(ls): if is_updates(ls):
return None, [], dict(ls) return None, [], OrderedDict(ls)
error_msg = ('Scan cannot parse the return value of your lambda ' error_msg = ('Scan cannot parse the return value of your lambda '
'expression, which is: %s' % (ls,)) 'expression, which is: %s' % (ls,))
if not isinstance(ls, (list, tuple)): if not isinstance(ls, (list, tuple)):
...@@ -258,16 +266,16 @@ def get_updates_and_outputs(ls): ...@@ -258,16 +266,16 @@ def get_updates_and_outputs(ls):
if len(ls) == 2: if len(ls) == 2:
if is_outputs(ls[0]): if is_outputs(ls[0]):
if is_updates(ls[1]): if is_updates(ls[1]):
return (None, _list(ls[0]), dict(ls[1])) return (None, _list(ls[0]), OrderedDict(ls[1]))
elif is_condition(ls[1]): elif is_condition(ls[1]):
return (ls[1].condition, _list(ls[0]), {}) return (ls[1].condition, _list(ls[0]), OrderedDict())
else: else:
raise ValueError(error_msg) raise ValueError(error_msg)
elif is_updates(ls[0]): elif is_updates(ls[0]):
if is_outputs(ls[1]): if is_outputs(ls[1]):
raise ValueError(deprecation_msg) raise ValueError(deprecation_msg)
elif is_condition(ls[1]): elif is_condition(ls[1]):
return (ls[1].condition, [], dict(ls[0])) return (ls[1].condition, [], OrderedDict(ls[0]))
else: else:
raise ValueError(error_msg) raise ValueError(error_msg)
else: else:
...@@ -276,7 +284,7 @@ def get_updates_and_outputs(ls): ...@@ -276,7 +284,7 @@ def get_updates_and_outputs(ls):
if is_outputs(ls[0]): if is_outputs(ls[0]):
if is_updates(ls[1]): if is_updates(ls[1]):
if is_condition(ls[2]): if is_condition(ls[2]):
return (ls[2].condition, _list(ls[0]), dict(ls[1])) return (ls[2].condition, _list(ls[0]), OrderedDict(ls[1]))
else: else:
raise ValueError(error_msg) raise ValueError(error_msg)
else: else:
......
...@@ -8,23 +8,25 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>" ...@@ -8,23 +8,25 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
from theano.gof.python25 import OrderedDict
from theano.compile.sharedvalue import SharedVariable from theano.compile.sharedvalue import SharedVariable
import logging import logging
logger = logging.getLogger('theano.updates') logger = logging.getLogger('theano.updates')
# Must be an OrderedDict or updates will be applied in a non-deterministic order
class Updates(dict): class OrderedUpdates(OrderedDict):
""" """
Dict-like mapping from SharedVariable keys to their new values. Dict-like mapping from SharedVariable keys to their new values.
This mapping supports the use of the "+" operator for the union of updates. This mapping supports the use of the "+" operator for the union of updates.
""" """
def __init__(self, *key, **kwargs): def __init__(self, *key, **kwargs):
ret = super(Updates, self).__init__(*key, **kwargs) ret = super(OrderedUpdates, self).__init__(*key, **kwargs)
for key in self: for key in self:
if not isinstance(key, SharedVariable): if not isinstance(key, SharedVariable):
raise TypeError( raise TypeError(
'Updates keys must inherit from SharedVariable', 'OrderedUpdates keys must inherit from SharedVariable',
key) key)
return ret return ret
...@@ -38,9 +40,9 @@ class Updates(dict): ...@@ -38,9 +40,9 @@ class Updates(dict):
# value. Should it be cast to a GPU value right away? Should # value. Should it be cast to a GPU value right away? Should
# literals be transformed into constants immediately? # literals be transformed into constants immediately?
return super(Updates, self).__setitem__(key, value) return super(OrderedUpdates, self).__setitem__(key, value)
else: else:
raise TypeError('Updates keys must inherit from SharedVariable', raise TypeError('OrderedUpdates keys must inherit from SharedVariable',
key) key)
def update(self, other): def update(self, other):
...@@ -52,13 +54,13 @@ class Updates(dict): ...@@ -52,13 +54,13 @@ class Updates(dict):
self[key] = val # __setitem__ does type-checking self[key] = val # __setitem__ does type-checking
def __add__(self, other): def __add__(self, other):
rval = Updates() rval = OrderedUpdates()
rval.update(self) rval.update(self)
rval.update(other) rval.update(other)
return rval return rval
def __radd__(other, self): def __radd__(other, self):
rval = Updates() rval = OrderedUpdates()
rval.update(other) rval.update(other)
rval.update(self) rval.update(self)
return rval return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论