提交 f36f9d6c authored 作者: lamblin's avatar lamblin

Merge pull request #1254 from nouiz/StochasticOrder_Scan

Try to fix Stochastic order scan by using OrderedDict more often.
......@@ -23,7 +23,7 @@ import theano
from theano.compile import function, Param, Out
from theano import compile
from theano import gradient
from theano.gof.python25 import any
from theano.gof.python25 import any, OrderedDict
from theano.gof import PureOp, Apply
from theano import gof
from theano.tensor import TensorType
......@@ -426,9 +426,9 @@ class Scan(PureOp):
if not type(self) == type(other):
return False
if not 'destroy_map' in self.info:
self.info['destroy_map'] = {}
self.info['destroy_map'] = OrderedDict()
if not 'destroy_map' in other.info:
other.info['destroy_map'] = {}
other.info['destroy_map'] = OrderedDict()
keys_to_check = ['truncate_gradient', 'profile',
'n_seqs', 'tap_array', 'name',
'as_while', 'n_mit_sot', 'destroy_map',
......@@ -472,7 +472,7 @@ class Scan(PureOp):
name = 'for'
aux_txt = '%s'
if getattr(self, 'destroy_map', None) is None:
self.destroy_map = {}
self.destroy_map = OrderedDict()
if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace
if (sorted(self.destroy_map.keys()) == \
......@@ -852,7 +852,7 @@ class Scan(PureOp):
pos = [(-self.mintaps[idx]) % store_steps[idx] for idx
in xrange(self.n_outs + self.n_nit_sot)]
if not getattr(self, 'destroy_map', None):
self.destroy_map = {}
self.destroy_map = OrderedDict()
# 2.1 Create storage space for outputs
for idx in xrange(self.n_outs):
if idx in self.destroy_map:
......@@ -1138,7 +1138,7 @@ class Scan(PureOp):
# Non-sequences have a direct equivalent from self.inputs in
# node.inputs
inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):]
out_equivalent = {}
out_equivalent = OrderedDict()
for in_ns, out_ns in izip(inner_non_sequences, node.inputs[offset:]):
out_equivalent[in_ns] = out_ns
if self.as_while:
......@@ -1257,7 +1257,7 @@ class Scan(PureOp):
def compute_gradient(y, g_y, diff_inputs):
rval = []
gmp = {}
gmp = OrderedDict()
consider_inps = [x for x in theano.gof.graph.inputs([y])
if x in diff_inputs]
for x in consider_inps:
......@@ -1695,7 +1695,7 @@ class Scan(PureOp):
new_tap_array = mitmot_inp_taps + [[-1] for k in
xrange(n_sitsot_outs)]
info = {}
info = OrderedDict()
info['n_seqs'] = len(outer_inp_seqs)
info['n_mit_sot'] = 0
info['tap_array'] = new_tap_array
......@@ -1709,7 +1709,7 @@ class Scan(PureOp):
info['n_nit_sot'] = n_nit_sot
info['as_while'] = False
info['profile'] = self.profile
info['destroy_map'] = {}
info['destroy_map'] = OrderedDict()
if self.name:
info['name'] = 'grad_of_' + self.name
else:
......@@ -1852,7 +1852,7 @@ class Scan(PureOp):
# The only exception is the eval point for the number of sequences, and
# eval point for the number of nit_sot which I think should just be
# ignored (?)
info = {}
info = OrderedDict()
info['n_seqs'] = self.n_seqs * 2
info['n_mit_sot'] = self.n_mit_sot * 2
info['n_sit_sot'] = self.n_sit_sot * 2
......@@ -1869,7 +1869,7 @@ class Scan(PureOp):
info['name'] = None
info['mode'] = self.mode
info['mit_mot_out_slices'] = self.mit_mot_out_slices * 2
info['destroy_map'] = {}
info['destroy_map'] = OrderedDict()
new_tap_array = []
b = 0
e = self.n_mit_mot
......
......@@ -20,7 +20,7 @@ import theano
from theano import tensor
from theano.tensor import opt, get_scalar_constant_value
from theano import gof
from theano.gof.python25 import maxsize, any
from theano.gof.python25 import maxsize, any, OrderedDict
from theano.gof.opt import Optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import optdb
......@@ -86,7 +86,7 @@ def remove_constants_and_unused_inputs_scan(node):
out_stuff_outer = node.inputs[1 + op.n_seqs:st]
# To replace constants in the outer graph by clones in the inner graph
givens = {}
givens = OrderedDict()
# All the inputs of the inner graph of the new scan
nw_inner = []
# Same for the outer graph, initialized w/ number of steps
......@@ -257,7 +257,7 @@ class PushOutNonSeqScan(gof.Optimizer):
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
givens = {}
givens = OrderedDict()
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace,
......@@ -284,7 +284,7 @@ class PushOutNonSeqScan(gof.Optimizer):
return True
elif to_keep == []:
# Nothing in the inner graph should be kept
replace_with = {}
replace_with = OrderedDict()
for idx, out in enumerate(to_replace):
if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)]
......@@ -439,7 +439,7 @@ class PushOutSeqScan(gof.Optimizer):
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
givens = {}
givens = OrderedDict()
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace,
......@@ -471,7 +471,7 @@ class PushOutSeqScan(gof.Optimizer):
not op.as_while and
not op.outer_mitmot(node)):
# Nothing in the inner graph should be kept
replace_with = gof.python25.OrderedDict()
replace_with = OrderedDict()
for idx, out in enumerate(to_replace):
if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)]
......@@ -529,7 +529,7 @@ class ScanInplaceOptimizer(Optimizer):
for pos in xrange(n_outs):
info = copy.deepcopy(op.info)
if not 'destroy_map' in info:
info['destroy_map'] = {}
info['destroy_map'] = OrderedDict()
info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']]
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs]
......@@ -600,7 +600,7 @@ class ScanSaveMem(gof.Optimizer):
# Each access to shape_of is in a try..except block in order to
# use a default version when the variable is not in the shape_of
# dictionary.
shape_of = {}
shape_of = OrderedDict()
# 1. Initialization of variables
# Note 1) We do not actually care about outputs representing shared
# variables (those have no intermediate values) so it is safer to
......@@ -923,7 +923,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.5 Remove unwanted orphan outputs
(inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = {}
inv_compress_map = OrderedDict()
for k, v in compress_map.items():
inv_compress_map[v] = k
......@@ -1053,7 +1053,7 @@ class ScanMerge(gof.Optimizer):
else:
as_while = False
info = {}
info = OrderedDict()
info['tap_array'] = []
info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes])
info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
......@@ -1228,7 +1228,7 @@ def has_duplicates(l):
def make_equiv(lo, li):
"""builds a dictionary of equivalences between inner inputs based on
the equivalence of their corresponding outer inputs."""
seeno = {}
seeno = OrderedDict()
left = []
right = []
for o, i in zip(lo, li):
......@@ -1248,7 +1248,7 @@ def scan_merge_inouts(node):
a = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info)
inp_equiv = {}
inp_equiv = OrderedDict()
if has_duplicates(a.outer_in_seqs):
new_outer_seqs = []
......@@ -1310,7 +1310,7 @@ def scan_merge_inouts(node):
left += _left
right += _right
if has_duplicates(na.outer_in_mit_mot):
seen = {}
seen = OrderedDict()
for omm, imm, _sl in zip(na.outer_in_mit_mot,
na.inner_in_mit_mot, na.mit_mot_in_slices):
sl = tuple(_sl)
......@@ -1322,7 +1322,7 @@ def scan_merge_inouts(node):
seen[(omm, sl)] = imm
if has_duplicates(na.outer_in_mit_sot):
seen = {}
seen = OrderedDict()
for oms, ims, _sl in zip(na.outer_in_mit_sot,
na.inner_in_mit_sot,
na.mit_sot_in_slices):
......
......@@ -221,7 +221,8 @@ def get_updates_and_outputs(ls):
def is_updates(elem):
if isinstance(elem, dict):
# Make sure the updates will be applied in a deterministic order
if not isinstance(elem, gof.python25.OrderedDict):
if (not isinstance(elem, gof.python25.OrderedDict) and
len(elem) > 1):
warnings.warn("Expected OrderedDict or OrderedUpdates, got "\
+ str(type(elem)) + ". This can make your script non-"
"deterministic.")
......@@ -511,7 +512,7 @@ class Validator(object):
if invalid is None:
invalid = []
if valid_equivalent is None:
valid_equivalent = {}
valid_equivalent = OrderedDict()
# Nodes that are valid to have in the graph computing outputs
self.valid = set(valid)
......@@ -624,7 +625,7 @@ def compress_outs(op, not_required, inputs):
means removing its inputs from the inner function and from the
node inputs, and changing the dictionary.
'''
info = {}
info = OrderedDict()
info['tap_array'] = []
info['n_seqs'] = op.info['n_seqs']
info['n_mit_mot'] = 0
......@@ -644,7 +645,7 @@ def compress_outs(op, not_required, inputs):
op_inputs = op.inputs[:op.n_seqs]
op_outputs = []
node_inputs = inputs[:op.n_seqs + 1]
map_old_new = {}
map_old_new = OrderedDict()
offset = 0
ni_offset = op.n_seqs + 1
......@@ -779,7 +780,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
if tag is None:
tag = ''
nw_inputs = [safe_new(x, tag) for x in inputs]
givens = {}
givens = OrderedDict()
for nw_x, x in izip(nw_inputs, inputs):
givens[x] = nw_x
allinputs = theano.gof.graph.inputs(outputs)
......@@ -899,7 +900,7 @@ class scan_args(object):
p += n_shared_outs
q += n_shared_outs
self.other_info = dict()
self.other_info = OrderedDict()
for k in ('truncate_gradient', 'name', 'mode', 'destroy_map',
'gpu', 'as_while', 'profile'):
if k in info:
......@@ -933,7 +934,7 @@ class scan_args(object):
self.outer_out_nit_sot +
self.outer_out_shared))
info = property(lambda self: dict(
info = property(lambda self: OrderedDict(
n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
......@@ -998,4 +999,4 @@ def forced_replace(out, x, y):
rval += traverse(inp, x)
return rval
to_replace = traverse(out, x)
return clone(out, replace=dict((v, y) for v in to_replace))
return clone(out, replace=OrderedDict((v, y) for v in to_replace))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论