提交 dff25cb4 authored 作者: Caglar's avatar Caglar

changed syntax for dict and few cosmetic changes.

上级 cbb487d7
...@@ -229,7 +229,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -229,7 +229,7 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False) local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
local_fgraph_topo = local_fgraph.toposort() local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs) local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict({v:k for k,v in enumerate(local_fgraph.outputs)}) local_fgraph_outs_map = dict([(v, k) for k,v in enumerate(local_fgraph.outputs)])
max_iterations = 2 * len(local_fgraph_topo) + 3 max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0 counts = 0
...@@ -250,7 +250,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -250,7 +250,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict({v:k for k,v in enumerate(inner_non_seqs)}) inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
...@@ -275,7 +275,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -275,7 +275,7 @@ class PushOutNonSeqScan(gof.Optimizer):
not isinstance(nd.op, theano.compile.ViewOp) and not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp) and not isinstance(nd.op, theano.compile.DeepCopyOp) and
# and we didn't already looked at this node # and we didn't already looked at this node
not nd in to_remove_set): nd not in to_remove_set):
# We have a candidate node to removable # We have a candidate node to removable
# Step 1. Reconstruct it on outside # Step 1. Reconstruct it on outside
...@@ -308,7 +308,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -308,7 +308,7 @@ class PushOutNonSeqScan(gof.Optimizer):
y_place_holder = scan_utils.safe_new(y, '_replace') y_place_holder = scan_utils.safe_new(y, '_replace')
nto_replace = add_to_replace(y, nto_replace) nto_replace = add_to_replace(y, nto_replace)
replace_with_in.append(y_place_holder) replace_with_in.append(y_place_holder)
assert type(y) == type(nw_outer_node.outputs[idx]) assert isinstance(y, type(nw_outer_node.outputs[idx]))
replace_with_out.append(nw_outer_node.outputs[idx]) replace_with_out.append(nw_outer_node.outputs[idx])
changed = True changed = True
...@@ -421,7 +421,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -421,7 +421,7 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False) local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
local_fgraph_topo = local_fgraph.toposort() local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs) local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict({v:k for k,v in enumerate(local_fgraph.outputs)}) local_fgraph_outs_map = dict([(v,k) for k,v in enumerate(local_fgraph.outputs)])
max_iterations = 2 * len(local_fgraph_topo) + 3 max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0 counts = 0
...@@ -445,12 +445,12 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -445,12 +445,12 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = dict({v:k for k,v in enumerate(inner_non_seqs)}) inner_non_seqs_map = dict([(v,k) for k,v in enumerate(inner_non_seqs)])
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs) inner_seqs_set = set(inner_seqs)
inner_seqs_map = dict({v:k for k,v in enumerate(inner_seqs)}) inner_seqs_map = dict([(v,k) for k,v in enumerate(inner_seqs)])
outer_seqs = op.outer_seqs(node.inputs) outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs) assert len(inner_non_seqs) == len(outer_non_seqs)
...@@ -564,7 +564,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -564,7 +564,7 @@ class PushOutSeqScan(gof.Optimizer):
to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes] to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes]
to_keep_set = set(to_keep) to_keep_set = set(to_keep)
for out, idx in to_replace_map.iteritems(): for out, idx in to_replace_map.items():
if (out in to_keep_set if (out in to_keep_set
and out.owner not in existent_nodes_set and out.owner not in existent_nodes_set
# If types are different, conversion Op will be inserted, # If types are different, conversion Op will be inserted,
...@@ -613,7 +613,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -613,7 +613,7 @@ class PushOutSeqScan(gof.Optimizer):
not op.outer_mitmot(node)): not op.outer_mitmot(node)):
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = OrderedDict() replace_with = OrderedDict()
for out, idx in to_replace_map.iteritems(): for out, idx in to_replace_map.items():
if out in local_fgraph_outs_set: if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]] x = node.outputs[local_fgraph_outs_map[out]]
_y = replace_with_out[idx] _y = replace_with_out[idx]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论