提交 6447ce7e authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed determinism issues in scan

上级 d6d377d6
......@@ -88,7 +88,7 @@ from printing import \
import scan_module
from scan_module import scan, map, reduce, foldl, foldr, clone
from updates import Updates
from updates import OrderedUpdates
import tensor
import scalar
......
......@@ -17,10 +17,11 @@ import numpy
from theano.compile import SharedVariable, function
from theano import compile
from theano import gof
from theano.gof.python25 import OrderedDict
from theano.tensor import opt
from theano import tensor
from theano import config
from theano.updates import Updates
from theano.updates import OrderedUpdates
from theano.scan_module import scan_op
......@@ -147,7 +148,7 @@ def scan(fn,
n_seqs = len(seqs)
n_outs = len(outs_info)
return_steps = {}
return_steps = OrderedDict()
# wrap outputs info in a dictionary if they are not already in one
for i in xrange(n_outs):
if outs_info[i] is not None:
......@@ -242,7 +243,7 @@ def scan(fn,
mit_sot_inner_inputs = []
mit_sot_inner_slices = []
mit_sot_inner_outputs = []
mit_sot_return_steps = {}
mit_sot_return_steps = OrderedDict()
mit_sot_tap_array = []
mit_sot_rightOrder = []
......@@ -251,7 +252,7 @@ def scan(fn,
sit_sot_inner_inputs = []
sit_sot_inner_slices = []
sit_sot_inner_outputs = []
sit_sot_return_steps = {}
sit_sot_return_steps = OrderedDict()
sit_sot_rightOrder = []
nit_sot_steps = []
# go through outputs picking up time slices as needed
......@@ -398,7 +399,8 @@ def scan(fn,
not isinstance(arg, tensor.Constant))]
# when we apply the lambda expression we get a mixture of update rules
# and outputs that need to be separated
condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args))
lambda_result = fn(*args)
condition, outputs, updates = scan_utils.get_updates_and_outputs(lambda_result)
if condition is not None:
as_while = True
else:
......@@ -464,6 +466,11 @@ def scan(fn,
dummy_outs = outputs
if condition is not None:
dummy_outs.append(condition)
# If we use a regular dict here, the results are non-deterministic
assert isinstance(updates, (list, tuple)) or (isinstance(updates, dict) and \
'Ordered' in str(type(updates)))
dummy_f = function(dummy_args,
dummy_outs,
updates=updates,
......@@ -508,7 +515,7 @@ def scan(fn,
sit_sot_inner_outputs.append(outputs[i])
## Step 5.3 Outputs that correspond to update rules of shared variables
givens = {}
givens = OrderedDict()
n_shared_outs = 0
shared_scan_inputs = []
shared_inner_inputs = []
......@@ -527,7 +534,7 @@ def scan(fn,
## Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0
nit_sot_inner_outputs = []
nit_sot_return_steps = {}
nit_sot_return_steps = OrderedDict()
nit_sot_rightOrder = []
for i, out in enumerate(outs_info):
if not 'taps' in out:
......@@ -582,7 +589,7 @@ def scan(fn,
shared_inner_outputs)
if condition is not None:
inner_outs.append(condition)
new_givens = {}
new_givens = OrderedDict()
for w, w_copy in givens.iteritems():
new_givens[w] = w.type.filter_variable(w_copy)
......@@ -593,7 +600,7 @@ def scan(fn,
##
tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
info = {}
info = OrderedDict()
info['tap_array'] = tap_array
info['n_seqs'] = n_seqs
......@@ -607,7 +614,7 @@ def scan(fn,
info['truncate_gradient'] = -1
info['name'] = name
info['mode'] = mode
info['destroy_map'] = {}
info['destroy_map'] = OrderedDict()
info['inplace'] = False
info['gpu'] = False
info['as_while'] = as_while
......@@ -641,7 +648,7 @@ def scan(fn,
### and so on ...
##
update_map = Updates()
update_map = OrderedUpdates()
offset = n_mit_mot
offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array]
......@@ -675,4 +682,5 @@ def scan(fn,
elif len(scan_out_list) == 0:
scan_out_list = None
assert isinstance(update_map, dict) and 'Ordered' in str(type(update_map))
return (scan_out_list, update_map)
......@@ -52,8 +52,9 @@ from theano import gof
from theano.tensor import opt
from theano import tensor
from theano import config
from theano.updates import Updates
from theano.updates import OrderedUpdates
from theano.compile import ops
from theano.gof.python25 import OrderedDict
import scan_op
......@@ -376,11 +377,11 @@ def scan(fn,
n_seqs = len(seqs)
n_outs = len(outs_info)
return_steps = {}
return_steps = OrderedDict()
# wrap sequences in a dictionary if they are not already dictionaries
for i in xrange(n_seqs):
if not isinstance(seqs[i], dict):
seqs[i] = dict(input=seqs[i], taps=[0])
seqs[i] = OrderedDict([('input', seqs[i]), ('taps', [0])])
elif seqs[i].get('taps', None):
seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
elif seqs[i].get('taps', True) is None:
......@@ -402,7 +403,7 @@ def scan(fn,
if not isinstance(outs_info[i], dict):
# by default any output has a tap value of -1
outs_info[i] = dict(initial=outs_info[i], taps=[-1])
outs_info[i] = OrderedDict([('initial', outs_info[i]), ('taps', [-1])])
elif (not outs_info[i].get('initial', None) and
outs_info[i].get('taps', None)):
# ^ no initial state but taps provided
......@@ -421,8 +422,8 @@ def scan(fn,
outs_info[i]['taps'] = [-1]
else:
# if a None is provided as the output info we replace it
# with an empty dict() to simplify handling
outs_info[i] = dict()
# with an empty OrderedDict() to simplify handling
outs_info[i] = OrderedDict()
##
### Step 2. Generate inputs and outputs of the inner functions
......@@ -565,7 +566,7 @@ def scan(fn,
mit_sot_inner_inputs = []
mit_sot_inner_slices = []
mit_sot_inner_outputs = []
mit_sot_return_steps = {}
mit_sot_return_steps = OrderedDict()
mit_sot_tap_array = []
mit_sot_rightOrder = []
......@@ -574,7 +575,7 @@ def scan(fn,
sit_sot_inner_inputs = []
sit_sot_inner_slices = []
sit_sot_inner_outputs = []
sit_sot_return_steps = {}
sit_sot_return_steps = OrderedDict()
sit_sot_rightOrder = []
# go through outputs picking up time slices as needed
......@@ -777,7 +778,7 @@ def scan(fn,
# as non sequences at the end of our args
fake_nonseqs = [x.type() for x in non_seqs]
fake_outputs = scan_utils.clone(outputs,
replace=dict(zip(non_seqs,
replace=OrderedDict(zip(non_seqs,
fake_nonseqs)))
all_inputs = itertools.ifilter(
lambda x: (isinstance(x, gof.Variable) and
......@@ -825,7 +826,7 @@ def scan(fn,
n_outs = len(dummy_f.maker.outputs)
if as_while:
n_outs = n_outs - 1
outs_info = [dict() for x in xrange(n_outs)]
outs_info = [OrderedDict() for x in xrange(n_outs)]
## Step 5.1 Outputs with taps different than -1
......@@ -839,7 +840,7 @@ def scan(fn,
sit_sot_inner_outputs.append(outputs[i])
## Step 5.3 Outputs that correspond to update rules of shared variables
givens = {}
givens = OrderedDict()
n_shared_outs = 0
shared_scan_inputs = []
shared_inner_inputs = []
......@@ -879,7 +880,7 @@ def scan(fn,
## Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0
nit_sot_inner_outputs = []
nit_sot_return_steps = {}
nit_sot_return_steps = OrderedDict()
nit_sot_rightOrder = []
for i, out in enumerate(outs_info):
if not 'taps' in out:
......@@ -902,7 +903,7 @@ def scan(fn,
if (not isinstance(arg, SharedVariable) and
not isinstance(arg, tensor.Constant))]
givens.update(dict(zip(other_scan_args, other_inner_args)))
givens.update(OrderedDict(zip(other_scan_args, other_inner_args)))
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
......@@ -911,7 +912,7 @@ def scan(fn,
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
givens.update(dict(zip(other_shared_scan_args,
givens.update(OrdereDict(zip(other_shared_scan_args,
other_shared_inner_args)))
##
......@@ -943,7 +944,7 @@ def scan(fn,
# replace w with w_copy, where w is CudaNdarray
# and w_copy is TensorType. This is caused because shared
# variables are put on GPU right away >:| ,
new_givens = {}
new_givens = OrderedDict()
for w, w_copy in givens.iteritems():
if (isinstance(w.type, cuda.CudaNdarrayType)
......@@ -962,7 +963,7 @@ def scan(fn,
##
tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
info = {}
info = OrderedDict()
info['tap_array'] = tap_array
info['n_seqs'] = n_seqs
......@@ -976,7 +977,7 @@ def scan(fn,
info['truncate_gradient'] = truncate_gradient
info['name'] = name
info['mode'] = mode
info['destroy_map'] = {}
info['destroy_map'] = OrderedDict()
info['gpu'] = False
info['as_while'] = as_while
info['profile'] = profile
......@@ -1012,7 +1013,7 @@ def scan(fn,
### and so on ...
##
update_map = Updates()
update_map = OrderedUpdates()
def remove_dimensions(outs, steps_return, offsets=None):
out_ls = []
......
......@@ -23,7 +23,7 @@ import theano
from theano.compile.pfunc import rebuild_collect_shared
from theano import gof
from theano import tensor, scalar
from theano.gof.python25 import all
from theano.gof.python25 import all, OrderedDict
from theano.tensor.basic import get_constant_value
......@@ -181,12 +181,17 @@ def clone(output,
def get_updates_and_outputs(ls):
"""
This function tries to recognize the updates dictionary, the
This function tries to recognize the updates OrderedDict, the
list of outputs and the stopping condition returned by the
lambda expression and arrange them in a predefined order
WRITEME: what is the type of ls? how is it formatted?
if it's not in the predefined order already, how does
this function know how to put it in that order?
"""
def is_outputs(elem):
if (isinstance(elem, (list, tuple)) and
all([isinstance(x, theano.Variable) for x in elem])):
......@@ -197,6 +202,8 @@ def get_updates_and_outputs(ls):
def is_updates(elem):
if isinstance(elem, dict):
# Make sure the updates will be applied in a deterministic order
assert 'Ordered' in str(type(elem))
return True
# Dictionaries can be given as lists of tuples
if (isinstance(elem, (list, tuple)) and
......@@ -240,12 +247,13 @@ def get_updates_and_outputs(ls):
'variables (or `theano.scan_module.until` objects for '
'conditions). In particular if you need to use constant '
'values, you can use `tensor.constant` to turn them into '
'Theano variables.')
'Theano variables.')
if is_outputs(ls):
return None, _list(ls), {}
return None, _list(ls), OrderedDict()
if is_updates(ls):
return None, [], dict(ls)
return None, [], OrderedDict(ls)
error_msg = ('Scan cannot parse the return value of your lambda '
'expression, which is: %s' % (ls,))
if not isinstance(ls, (list, tuple)):
......@@ -258,16 +266,16 @@ def get_updates_and_outputs(ls):
if len(ls) == 2:
if is_outputs(ls[0]):
if is_updates(ls[1]):
return (None, _list(ls[0]), dict(ls[1]))
return (None, _list(ls[0]), OrderedDict(ls[1]))
elif is_condition(ls[1]):
return (ls[1].condition, _list(ls[0]), {})
return (ls[1].condition, _list(ls[0]), OrderedDict())
else:
raise ValueError(error_msg)
elif is_updates(ls[0]):
if is_outputs(ls[1]):
raise ValueError(deprecation_msg)
elif is_condition(ls[1]):
return (ls[1].condition, [], dict(ls[0]))
return (ls[1].condition, [], OrderedDict(ls[0]))
else:
raise ValueError(error_msg)
else:
......@@ -276,7 +284,7 @@ def get_updates_and_outputs(ls):
if is_outputs(ls[0]):
if is_updates(ls[1]):
if is_condition(ls[2]):
return (ls[2].condition, _list(ls[0]), dict(ls[1]))
return (ls[2].condition, _list(ls[0]), OrderedDict(ls[1]))
else:
raise ValueError(error_msg)
else:
......
......@@ -8,23 +8,25 @@ __contact__ = "theano-dev <theano-dev@googlegroups.com>"
__docformat__ = "restructuredtext en"
from theano.gof.python25 import OrderedDict
from theano.compile.sharedvalue import SharedVariable
import logging
logger = logging.getLogger('theano.updates')
class Updates(dict):
# Must be an OrderedDict or updates will be applied in a non-deterministic order
class OrderedUpdates(OrderedDict):
"""
Dict-like mapping from SharedVariable keys to their new values.
This mapping supports the use of the "+" operator for the union of updates.
"""
def __init__(self, *key, **kwargs):
ret = super(Updates, self).__init__(*key, **kwargs)
ret = super(OrderedUpdates, self).__init__(*key, **kwargs)
for key in self:
if not isinstance(key, SharedVariable):
raise TypeError(
'Updates keys must inherit from SharedVariable',
'OrderedUpdates keys must inherit from SharedVariable',
key)
return ret
......@@ -38,9 +40,9 @@ class Updates(dict):
# value. Should it be cast to a GPU value right away? Should
# literals be transformed into constants immediately?
return super(Updates, self).__setitem__(key, value)
return super(OrderedUpdates, self).__setitem__(key, value)
else:
raise TypeError('Updates keys must inherit from SharedVariable',
raise TypeError('OrderedUpdates keys must inherit from SharedVariable',
key)
def update(self, other):
......@@ -52,13 +54,13 @@ class Updates(dict):
self[key] = val # __setitem__ does type-checking
def __add__(self, other):
rval = Updates()
rval = OrderedUpdates()
rval.update(self)
rval.update(other)
return rval
def __radd__(other, self):
rval = Updates()
rval = OrderedUpdates()
rval.update(other)
rval.update(self)
return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论