提交 cbb487d7 authored 作者: Caglar's avatar Caglar

more datastructure related changes.

上级 1e86bc88
......@@ -61,6 +61,7 @@ import logging
import copy
from sys import maxsize
import numpy
from itertools import izip
import theano
from theano import tensor
......@@ -227,6 +228,8 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict({v:k for k,v in enumerate(local_fgraph.outputs)})
max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0
......@@ -294,7 +297,7 @@ class PushOutNonSeqScan(gof.Optimizer):
'which is not allowed to move. Report '
'this on theano-users list'), x)
outside_ins = [x.type.filter_variable(y) for x, y in
zip(nd.inputs, outside_ins)]
izip(nd.inputs, outside_ins)]
# Do not call make_node for test_value
nw_outer_node = nd.op(*outside_ins,
......@@ -327,7 +330,7 @@ class PushOutNonSeqScan(gof.Optimizer):
to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes]
to_keep_set = set(to_keep)
for out, idx in to_replace_map.iteritems():
for out, idx in to_replace_map.items():
if (out in to_keep_set
and out.owner not in existent_nodes_set
# If types are different, conversion Op will be inserted,
......@@ -342,7 +345,7 @@ class PushOutNonSeqScan(gof.Optimizer):
givens = OrderedDict()
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace,
for to_repl, repl_in, repl_out in izip(clean_to_replace,
clean_replace_with_in,
clean_replace_with_out):
if isinstance(repl_out, theano.Constant):
......@@ -371,9 +374,9 @@ class PushOutNonSeqScan(gof.Optimizer):
elif not to_keep:
# Nothing in the inner graph should be kept
replace_with = OrderedDict()
for out, idx in to_replace_map.iteritems():
if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)]
for out, idx in to_replace_map.items():
if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]]
y = replace_with_out[idx]
shape = [shp for shp in y.shape]
replace_with[x] = tensor.alloc(y,
......@@ -417,6 +420,9 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict({v:k for k,v in enumerate(local_fgraph.outputs)})
max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0
......@@ -538,6 +544,7 @@ class PushOutSeqScan(gof.Optimizer):
assert new_sh == ref_sh
changed = True
if counts >= max_iterations:
raise Exception('Error in the `scan_pushout_seq_operations`.'
' The optimization exhausted the maximal number '
......@@ -573,7 +580,7 @@ class PushOutSeqScan(gof.Optimizer):
givens = OrderedDict()
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace,
for to_repl, repl_in, repl_out in izip(clean_to_replace,
clean_replace_with_in,
clean_replace_with_out):
if isinstance(repl_out, theano.Constant):
......@@ -607,8 +614,8 @@ class PushOutSeqScan(gof.Optimizer):
# Nothing in the inner graph should be kept
replace_with = OrderedDict()
for out, idx in to_replace_map.iteritems():
if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)]
if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]]
_y = replace_with_out[idx]
ls = local_fgraph.outputs
if out in op.inner_mitsot_outs(ls):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论