提交 0d02e65e authored 作者: Frederic's avatar Frederic

more OrderedDict

上级 f4af10da
...@@ -492,7 +492,7 @@ class Validator(object): ...@@ -492,7 +492,7 @@ class Validator(object):
if invalid is None: if invalid is None:
invalid = [] invalid = []
if valid_equivalent is None: if valid_equivalent is None:
valid_equivalent = {} valid_equivalent = OrderedDict()
# Nodes that are valid to have in the graph computing outputs # Nodes that are valid to have in the graph computing outputs
self.valid = set(valid) self.valid = set(valid)
...@@ -605,7 +605,7 @@ def compress_outs(op, not_required, inputs): ...@@ -605,7 +605,7 @@ def compress_outs(op, not_required, inputs):
means removing its inputs from the inner funciton and from the means removing its inputs from the inner funciton and from the
node inputs, and changing the dictionary. node inputs, and changing the dictionary.
''' '''
info = {} info = OrderedDict()
info['tap_array'] = [] info['tap_array'] = []
info['n_seqs'] = op.info['n_seqs'] info['n_seqs'] = op.info['n_seqs']
info['n_mit_mot'] = 0 info['n_mit_mot'] = 0
...@@ -625,7 +625,7 @@ def compress_outs(op, not_required, inputs): ...@@ -625,7 +625,7 @@ def compress_outs(op, not_required, inputs):
op_inputs = op.inputs[:op.n_seqs] op_inputs = op.inputs[:op.n_seqs]
op_outputs = [] op_outputs = []
node_inputs = inputs[:op.n_seqs + 1] node_inputs = inputs[:op.n_seqs + 1]
map_old_new = {} map_old_new = OrderedDict()
offset = 0 offset = 0
ni_offset = op.n_seqs + 1 ni_offset = op.n_seqs + 1
...@@ -760,7 +760,7 @@ def reconstruct_graph(inputs, outputs, tag=None): ...@@ -760,7 +760,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
if tag is None: if tag is None:
tag = '' tag = ''
nw_inputs = [safe_new(x, tag) for x in inputs] nw_inputs = [safe_new(x, tag) for x in inputs]
givens = {} givens = OrderedDict()
for nw_x, x in izip(nw_inputs, inputs): for nw_x, x in izip(nw_inputs, inputs):
givens[x] = nw_x givens[x] = nw_x
allinputs = theano.gof.graph.inputs(outputs) allinputs = theano.gof.graph.inputs(outputs)
...@@ -880,7 +880,7 @@ class scan_args(object): ...@@ -880,7 +880,7 @@ class scan_args(object):
p += n_shared_outs p += n_shared_outs
q += n_shared_outs q += n_shared_outs
self.other_info = dict() self.other_info = OrderedDict()
for k in ('truncate_gradient', 'name', 'mode', 'destroy_map', for k in ('truncate_gradient', 'name', 'mode', 'destroy_map',
'gpu', 'as_while', 'profile'): 'gpu', 'as_while', 'profile'):
if k in info: if k in info:
...@@ -914,7 +914,7 @@ class scan_args(object): ...@@ -914,7 +914,7 @@ class scan_args(object):
self.outer_out_nit_sot + self.outer_out_nit_sot +
self.outer_out_shared)) self.outer_out_shared))
info = property(lambda self: dict( info = property(lambda self: OrderedDict(
n_seqs=len(self.outer_in_seqs), n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot), n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot), n_mit_sot=len(self.outer_in_mit_sot),
...@@ -979,4 +979,4 @@ def forced_replace(out, x, y): ...@@ -979,4 +979,4 @@ def forced_replace(out, x, y):
rval += traverse(inp, x) rval += traverse(inp, x)
return rval return rval
to_replace = traverse(out, x) to_replace = traverse(out, x)
return clone(out, replace=dict((v, y) for v in to_replace)) return clone(out, replace=OrderedDict((v, y) for v in to_replace))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论