提交 f36f9d6c，作者：lamblin

Merge pull request #1254 from nouiz/StochasticOrder_Scan

Try to fix Stochastic order scan by using OrderedDict more often.
...@@ -23,7 +23,7 @@ import theano ...@@ -23,7 +23,7 @@ import theano
from theano.compile import function, Param, Out from theano.compile import function, Param, Out
from theano import compile from theano import compile
from theano import gradient from theano import gradient
from theano.gof.python25 import any from theano.gof.python25 import any, OrderedDict
from theano.gof import PureOp, Apply from theano.gof import PureOp, Apply
from theano import gof from theano import gof
from theano.tensor import TensorType from theano.tensor import TensorType
...@@ -426,9 +426,9 @@ class Scan(PureOp): ...@@ -426,9 +426,9 @@ class Scan(PureOp):
if not type(self) == type(other): if not type(self) == type(other):
return False return False
if not 'destroy_map' in self.info: if not 'destroy_map' in self.info:
self.info['destroy_map'] = {} self.info['destroy_map'] = OrderedDict()
if not 'destroy_map' in other.info: if not 'destroy_map' in other.info:
other.info['destroy_map'] = {} other.info['destroy_map'] = OrderedDict()
keys_to_check = ['truncate_gradient', 'profile', keys_to_check = ['truncate_gradient', 'profile',
'n_seqs', 'tap_array', 'name', 'n_seqs', 'tap_array', 'name',
'as_while', 'n_mit_sot', 'destroy_map', 'as_while', 'n_mit_sot', 'destroy_map',
...@@ -472,7 +472,7 @@ class Scan(PureOp): ...@@ -472,7 +472,7 @@ class Scan(PureOp):
name = 'for' name = 'for'
aux_txt = '%s' aux_txt = '%s'
if getattr(self, 'destroy_map', None) is None: if getattr(self, 'destroy_map', None) is None:
self.destroy_map = {} self.destroy_map = OrderedDict()
if len(self.destroy_map.keys()) > 0: if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace # Check if all outputs are inplace
if (sorted(self.destroy_map.keys()) == \ if (sorted(self.destroy_map.keys()) == \
...@@ -852,7 +852,7 @@ class Scan(PureOp): ...@@ -852,7 +852,7 @@ class Scan(PureOp):
pos = [(-self.mintaps[idx]) % store_steps[idx] for idx pos = [(-self.mintaps[idx]) % store_steps[idx] for idx
in xrange(self.n_outs + self.n_nit_sot)] in xrange(self.n_outs + self.n_nit_sot)]
if not getattr(self, 'destroy_map', None): if not getattr(self, 'destroy_map', None):
self.destroy_map = {} self.destroy_map = OrderedDict()
# 2.1 Create storage space for outputs # 2.1 Create storage space for outputs
for idx in xrange(self.n_outs): for idx in xrange(self.n_outs):
if idx in self.destroy_map: if idx in self.destroy_map:
...@@ -1138,7 +1138,7 @@ class Scan(PureOp): ...@@ -1138,7 +1138,7 @@ class Scan(PureOp):
# Non-sequences have a direct equivalent from self.inputs in # Non-sequences have a direct equivalent from self.inputs in
# node.inputs # node.inputs
inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):] inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):]
out_equivalent = {} out_equivalent = OrderedDict()
for in_ns, out_ns in izip(inner_non_sequences, node.inputs[offset:]): for in_ns, out_ns in izip(inner_non_sequences, node.inputs[offset:]):
out_equivalent[in_ns] = out_ns out_equivalent[in_ns] = out_ns
if self.as_while: if self.as_while:
...@@ -1257,7 +1257,7 @@ class Scan(PureOp): ...@@ -1257,7 +1257,7 @@ class Scan(PureOp):
def compute_gradient(y, g_y, diff_inputs): def compute_gradient(y, g_y, diff_inputs):
rval = [] rval = []
gmp = {} gmp = OrderedDict()
consider_inps = [x for x in theano.gof.graph.inputs([y]) consider_inps = [x for x in theano.gof.graph.inputs([y])
if x in diff_inputs] if x in diff_inputs]
for x in consider_inps: for x in consider_inps:
...@@ -1695,7 +1695,7 @@ class Scan(PureOp): ...@@ -1695,7 +1695,7 @@ class Scan(PureOp):
new_tap_array = mitmot_inp_taps + [[-1] for k in new_tap_array = mitmot_inp_taps + [[-1] for k in
xrange(n_sitsot_outs)] xrange(n_sitsot_outs)]
info = {} info = OrderedDict()
info['n_seqs'] = len(outer_inp_seqs) info['n_seqs'] = len(outer_inp_seqs)
info['n_mit_sot'] = 0 info['n_mit_sot'] = 0
info['tap_array'] = new_tap_array info['tap_array'] = new_tap_array
...@@ -1709,7 +1709,7 @@ class Scan(PureOp): ...@@ -1709,7 +1709,7 @@ class Scan(PureOp):
info['n_nit_sot'] = n_nit_sot info['n_nit_sot'] = n_nit_sot
info['as_while'] = False info['as_while'] = False
info['profile'] = self.profile info['profile'] = self.profile
info['destroy_map'] = {} info['destroy_map'] = OrderedDict()
if self.name: if self.name:
info['name'] = 'grad_of_' + self.name info['name'] = 'grad_of_' + self.name
else: else:
...@@ -1852,7 +1852,7 @@ class Scan(PureOp): ...@@ -1852,7 +1852,7 @@ class Scan(PureOp):
# The only exception is the eval point for the number of sequences, and # The only exception is the eval point for the number of sequences, and
# evan point for the number of nit_sot which I think should just be # evan point for the number of nit_sot which I think should just be
# ignored (?) # ignored (?)
info = {} info = OrderedDict()
info['n_seqs'] = self.n_seqs * 2 info['n_seqs'] = self.n_seqs * 2
info['n_mit_sot'] = self.n_mit_sot * 2 info['n_mit_sot'] = self.n_mit_sot * 2
info['n_sit_sot'] = self.n_sit_sot * 2 info['n_sit_sot'] = self.n_sit_sot * 2
...@@ -1869,7 +1869,7 @@ class Scan(PureOp): ...@@ -1869,7 +1869,7 @@ class Scan(PureOp):
info['name'] = None info['name'] = None
info['mode'] = self.mode info['mode'] = self.mode
info['mit_mot_out_slices'] = self.mit_mot_out_slices * 2 info['mit_mot_out_slices'] = self.mit_mot_out_slices * 2
info['destroy_map'] = {} info['destroy_map'] = OrderedDict()
new_tap_array = [] new_tap_array = []
b = 0 b = 0
e = self.n_mit_mot e = self.n_mit_mot
......
...@@ -20,7 +20,7 @@ import theano ...@@ -20,7 +20,7 @@ import theano
from theano import tensor from theano import tensor
from theano.tensor import opt, get_scalar_constant_value from theano.tensor import opt, get_scalar_constant_value
from theano import gof from theano import gof
from theano.gof.python25 import maxsize, any from theano.gof.python25 import maxsize, any, OrderedDict
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gof import toolbox, DestroyHandler, InconsistencyError from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.compile import optdb from theano.compile import optdb
...@@ -86,7 +86,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -86,7 +86,7 @@ def remove_constants_and_unused_inputs_scan(node):
out_stuff_outer = node.inputs[1 + op.n_seqs:st] out_stuff_outer = node.inputs[1 + op.n_seqs:st]
# To replace constants in the outer graph by clones in the inner graph # To replace constants in the outer graph by clones in the inner graph
givens = {} givens = OrderedDict()
# All the inputs of the inner graph of the new scan # All the inputs of the inner graph of the new scan
nw_inner = [] nw_inner = []
# Same for the outer graph, initialized w/ number of steps # Same for the outer graph, initialized w/ number of steps
...@@ -257,7 +257,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -257,7 +257,7 @@ class PushOutNonSeqScan(gof.Optimizer):
if len(clean_to_replace) > 0: if len(clean_to_replace) > 0:
# We can finally put an end to all this madness # We can finally put an end to all this madness
givens = {} givens = OrderedDict()
nw_outer = [] nw_outer = []
nw_inner = [] nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace, for to_repl, repl_in, repl_out in zip(clean_to_replace,
...@@ -284,7 +284,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -284,7 +284,7 @@ class PushOutNonSeqScan(gof.Optimizer):
return True return True
elif to_keep == []: elif to_keep == []:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = {} replace_with = OrderedDict()
for idx, out in enumerate(to_replace): for idx, out in enumerate(to_replace):
if out in local_fgraph.outputs: if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)] x = node.outputs[local_fgraph.outputs.index(out)]
...@@ -439,7 +439,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -439,7 +439,7 @@ class PushOutSeqScan(gof.Optimizer):
if len(clean_to_replace) > 0: if len(clean_to_replace) > 0:
# We can finally put an end to all this madness # We can finally put an end to all this madness
givens = {} givens = OrderedDict()
nw_outer = [] nw_outer = []
nw_inner = [] nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace, for to_repl, repl_in, repl_out in zip(clean_to_replace,
...@@ -471,7 +471,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -471,7 +471,7 @@ class PushOutSeqScan(gof.Optimizer):
not op.as_while and not op.as_while and
not op.outer_mitmot(node)): not op.outer_mitmot(node)):
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = gof.python25.OrderedDict() replace_with = OrderedDict()
for idx, out in enumerate(to_replace): for idx, out in enumerate(to_replace):
if out in local_fgraph.outputs: if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)] x = node.outputs[local_fgraph.outputs.index(out)]
...@@ -529,7 +529,7 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -529,7 +529,7 @@ class ScanInplaceOptimizer(Optimizer):
for pos in xrange(n_outs): for pos in xrange(n_outs):
info = copy.deepcopy(op.info) info = copy.deepcopy(op.info)
if not 'destroy_map' in info: if not 'destroy_map' in info:
info['destroy_map'] = {} info['destroy_map'] = OrderedDict()
info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']] info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']]
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs] ls_begin = node.inputs[:1 + op.n_seqs]
...@@ -600,7 +600,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -600,7 +600,7 @@ class ScanSaveMem(gof.Optimizer):
# Each access to shape_of is in a try..except block in order to # Each access to shape_of is in a try..except block in order to
# use a default version when the variable is not in the shape_of # use a default version when the variable is not in the shape_of
# dictionary. # dictionary.
shape_of = {} shape_of = OrderedDict()
# 1. Initialization of variables # 1. Initialization of variables
# Note 1) We do not actually care about outputs representing shared # Note 1) We do not actually care about outputs representing shared
# variables (those have no intermediate values) so it is safer to # variables (those have no intermediate values) so it is safer to
...@@ -923,7 +923,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -923,7 +923,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.5 Remove unwanted orphane outputs # 3.5 Remove unwanted orphane outputs
(inps, outs, info, node_ins, compress_map) = \ (inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs) scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = {} inv_compress_map = OrderedDict()
for k, v in compress_map.items(): for k, v in compress_map.items():
inv_compress_map[v] = k inv_compress_map[v] = k
...@@ -1053,7 +1053,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1053,7 +1053,7 @@ class ScanMerge(gof.Optimizer):
else: else:
as_while = False as_while = False
info = {} info = OrderedDict()
info['tap_array'] = [] info['tap_array'] = []
info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes]) info['n_seqs'] = sum([nd.op.n_seqs for nd in nodes])
info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes]) info['n_mit_mot'] = sum([nd.op.n_mit_mot for nd in nodes])
...@@ -1228,7 +1228,7 @@ def has_duplicates(l): ...@@ -1228,7 +1228,7 @@ def has_duplicates(l):
def make_equiv(lo, li): def make_equiv(lo, li):
"""builds a dictionary of equivalences between inner inputs based on """builds a dictionary of equivalences between inner inputs based on
the equivalence of their corresponding outer inputs.""" the equivalence of their corresponding outer inputs."""
seeno = {} seeno = OrderedDict()
left = [] left = []
right = [] right = []
for o, i in zip(lo, li): for o, i in zip(lo, li):
...@@ -1248,7 +1248,7 @@ def scan_merge_inouts(node): ...@@ -1248,7 +1248,7 @@ def scan_merge_inouts(node):
a = scan_args(node.inputs, node.outputs, a = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info) node.op.inputs, node.op.outputs, node.op.info)
inp_equiv = {} inp_equiv = OrderedDict()
if has_duplicates(a.outer_in_seqs): if has_duplicates(a.outer_in_seqs):
new_outer_seqs = [] new_outer_seqs = []
...@@ -1310,7 +1310,7 @@ def scan_merge_inouts(node): ...@@ -1310,7 +1310,7 @@ def scan_merge_inouts(node):
left += _left left += _left
right += _right right += _right
if has_duplicates(na.outer_in_mit_mot): if has_duplicates(na.outer_in_mit_mot):
seen = {} seen = OrderedDict()
for omm, imm, _sl in zip(na.outer_in_mit_mot, for omm, imm, _sl in zip(na.outer_in_mit_mot,
na.inner_in_mit_mot, na.mit_mot_in_slices): na.inner_in_mit_mot, na.mit_mot_in_slices):
sl = tuple(_sl) sl = tuple(_sl)
...@@ -1322,7 +1322,7 @@ def scan_merge_inouts(node): ...@@ -1322,7 +1322,7 @@ def scan_merge_inouts(node):
seen[(omm, sl)] = imm seen[(omm, sl)] = imm
if has_duplicates(na.outer_in_mit_sot): if has_duplicates(na.outer_in_mit_sot):
seen = {} seen = OrderedDict()
for oms, ims, _sl in zip(na.outer_in_mit_sot, for oms, ims, _sl in zip(na.outer_in_mit_sot,
na.inner_in_mit_sot, na.inner_in_mit_sot,
na.mit_sot_in_slices): na.mit_sot_in_slices):
......
...@@ -221,7 +221,8 @@ def get_updates_and_outputs(ls): ...@@ -221,7 +221,8 @@ def get_updates_and_outputs(ls):
def is_updates(elem): def is_updates(elem):
if isinstance(elem, dict): if isinstance(elem, dict):
# Make sure the updates will be applied in a deterministic order # Make sure the updates will be applied in a deterministic order
if not isinstance(elem, gof.python25.OrderedDict): if (not isinstance(elem, gof.python25.OrderedDict) and
len(elem) > 1):
warnings.warn("Expected OrderedDict or OrderedUpdates, got "\ warnings.warn("Expected OrderedDict or OrderedUpdates, got "\
+ str(type(elem)) + ". This can make your script non-" + str(type(elem)) + ". This can make your script non-"
"deterministic.") "deterministic.")
...@@ -511,7 +512,7 @@ class Validator(object): ...@@ -511,7 +512,7 @@ class Validator(object):
if invalid is None: if invalid is None:
invalid = [] invalid = []
if valid_equivalent is None: if valid_equivalent is None:
valid_equivalent = {} valid_equivalent = OrderedDict()
# Nodes that are valid to have in the graph computing outputs # Nodes that are valid to have in the graph computing outputs
self.valid = set(valid) self.valid = set(valid)
...@@ -624,7 +625,7 @@ def compress_outs(op, not_required, inputs): ...@@ -624,7 +625,7 @@ def compress_outs(op, not_required, inputs):
means removing its inputs from the inner funciton and from the means removing its inputs from the inner funciton and from the
node inputs, and changing the dictionary. node inputs, and changing the dictionary.
''' '''
info = {} info = OrderedDict()
info['tap_array'] = [] info['tap_array'] = []
info['n_seqs'] = op.info['n_seqs'] info['n_seqs'] = op.info['n_seqs']
info['n_mit_mot'] = 0 info['n_mit_mot'] = 0
...@@ -644,7 +645,7 @@ def compress_outs(op, not_required, inputs): ...@@ -644,7 +645,7 @@ def compress_outs(op, not_required, inputs):
op_inputs = op.inputs[:op.n_seqs] op_inputs = op.inputs[:op.n_seqs]
op_outputs = [] op_outputs = []
node_inputs = inputs[:op.n_seqs + 1] node_inputs = inputs[:op.n_seqs + 1]
map_old_new = {} map_old_new = OrderedDict()
offset = 0 offset = 0
ni_offset = op.n_seqs + 1 ni_offset = op.n_seqs + 1
...@@ -779,7 +780,7 @@ def reconstruct_graph(inputs, outputs, tag=None): ...@@ -779,7 +780,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
if tag is None: if tag is None:
tag = '' tag = ''
nw_inputs = [safe_new(x, tag) for x in inputs] nw_inputs = [safe_new(x, tag) for x in inputs]
givens = {} givens = OrderedDict()
for nw_x, x in izip(nw_inputs, inputs): for nw_x, x in izip(nw_inputs, inputs):
givens[x] = nw_x givens[x] = nw_x
allinputs = theano.gof.graph.inputs(outputs) allinputs = theano.gof.graph.inputs(outputs)
...@@ -899,7 +900,7 @@ class scan_args(object): ...@@ -899,7 +900,7 @@ class scan_args(object):
p += n_shared_outs p += n_shared_outs
q += n_shared_outs q += n_shared_outs
self.other_info = dict() self.other_info = OrderedDict()
for k in ('truncate_gradient', 'name', 'mode', 'destroy_map', for k in ('truncate_gradient', 'name', 'mode', 'destroy_map',
'gpu', 'as_while', 'profile'): 'gpu', 'as_while', 'profile'):
if k in info: if k in info:
...@@ -933,7 +934,7 @@ class scan_args(object): ...@@ -933,7 +934,7 @@ class scan_args(object):
self.outer_out_nit_sot + self.outer_out_nit_sot +
self.outer_out_shared)) self.outer_out_shared))
info = property(lambda self: dict( info = property(lambda self: OrderedDict(
n_seqs=len(self.outer_in_seqs), n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot), n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot), n_mit_sot=len(self.outer_in_mit_sot),
...@@ -998,4 +999,4 @@ def forced_replace(out, x, y): ...@@ -998,4 +999,4 @@ def forced_replace(out, x, y):
rval += traverse(inp, x) rval += traverse(inp, x)
return rval return rval
to_replace = traverse(out, x) to_replace = traverse(out, x)
return clone(out, replace=dict((v, y) for v in to_replace)) return clone(out, replace=OrderedDict((v, y) for v in to_replace))
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论