提交 cbb487d7 authored 作者: Caglar's avatar Caglar

more datastructure related changes.

上级 1e86bc88
...@@ -61,6 +61,7 @@ import logging ...@@ -61,6 +61,7 @@ import logging
import copy import copy
from sys import maxsize from sys import maxsize
import numpy import numpy
from itertools import izip
import theano import theano
from theano import tensor from theano import tensor
...@@ -227,6 +228,8 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -227,6 +228,8 @@ class PushOutNonSeqScan(gof.Optimizer):
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False) local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
local_fgraph_topo = local_fgraph.toposort() local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict({v:k for k,v in enumerate(local_fgraph.outputs)})
max_iterations = 2 * len(local_fgraph_topo) + 3 max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0 counts = 0
...@@ -294,7 +297,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -294,7 +297,7 @@ class PushOutNonSeqScan(gof.Optimizer):
'which is not allowed to move. Report ' 'which is not allowed to move. Report '
'this on theano-users list'), x) 'this on theano-users list'), x)
outside_ins = [x.type.filter_variable(y) for x, y in outside_ins = [x.type.filter_variable(y) for x, y in
zip(nd.inputs, outside_ins)] izip(nd.inputs, outside_ins)]
# Do not call make_node for test_value # Do not call make_node for test_value
nw_outer_node = nd.op(*outside_ins, nw_outer_node = nd.op(*outside_ins,
...@@ -327,7 +330,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -327,7 +330,7 @@ class PushOutNonSeqScan(gof.Optimizer):
to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes] to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes]
to_keep_set = set(to_keep) to_keep_set = set(to_keep)
for out, idx in to_replace_map.iteritems(): for out, idx in to_replace_map.items():
if (out in to_keep_set if (out in to_keep_set
and out.owner not in existent_nodes_set and out.owner not in existent_nodes_set
# If types are different, conversion Op will be inserted, # If types are different, conversion Op will be inserted,
...@@ -342,7 +345,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -342,7 +345,7 @@ class PushOutNonSeqScan(gof.Optimizer):
givens = OrderedDict() givens = OrderedDict()
nw_outer = [] nw_outer = []
nw_inner = [] nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace, for to_repl, repl_in, repl_out in izip(clean_to_replace,
clean_replace_with_in, clean_replace_with_in,
clean_replace_with_out): clean_replace_with_out):
if isinstance(repl_out, theano.Constant): if isinstance(repl_out, theano.Constant):
...@@ -371,9 +374,9 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -371,9 +374,9 @@ class PushOutNonSeqScan(gof.Optimizer):
elif not to_keep: elif not to_keep:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = OrderedDict() replace_with = OrderedDict()
for out, idx in to_replace_map.iteritems(): for out, idx in to_replace_map.items():
if out in local_fgraph.outputs: if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph.outputs.index(out)] x = node.outputs[local_fgraph_outs_map[out]]
y = replace_with_out[idx] y = replace_with_out[idx]
shape = [shp for shp in y.shape] shape = [shp for shp in y.shape]
replace_with[x] = tensor.alloc(y, replace_with[x] = tensor.alloc(y,
...@@ -417,6 +420,9 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -417,6 +420,9 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False) local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
local_fgraph_topo = local_fgraph.toposort() local_fgraph_topo = local_fgraph.toposort()
local_fgraph_outs_set = set(local_fgraph.outputs)
local_fgraph_outs_map = dict({v:k for k,v in enumerate(local_fgraph.outputs)})
max_iterations = 2 * len(local_fgraph_topo) + 3 max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0 counts = 0
...@@ -538,6 +544,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -538,6 +544,7 @@ class PushOutSeqScan(gof.Optimizer):
assert new_sh == ref_sh assert new_sh == ref_sh
changed = True changed = True
if counts >= max_iterations: if counts >= max_iterations:
raise Exception('Error in the `scan_pushout_seq_operations`.' raise Exception('Error in the `scan_pushout_seq_operations`.'
' The optimization exhausted the maximal number ' ' The optimization exhausted the maximal number '
...@@ -573,7 +580,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -573,7 +580,7 @@ class PushOutSeqScan(gof.Optimizer):
givens = OrderedDict() givens = OrderedDict()
nw_outer = [] nw_outer = []
nw_inner = [] nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace, for to_repl, repl_in, repl_out in izip(clean_to_replace,
clean_replace_with_in, clean_replace_with_in,
clean_replace_with_out): clean_replace_with_out):
if isinstance(repl_out, theano.Constant): if isinstance(repl_out, theano.Constant):
...@@ -607,8 +614,8 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -607,8 +614,8 @@ class PushOutSeqScan(gof.Optimizer):
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = OrderedDict() replace_with = OrderedDict()
for out, idx in to_replace_map.iteritems(): for out, idx in to_replace_map.iteritems():
if out in local_fgraph.outputs: if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph.outputs.index(out)] x = node.outputs[local_fgraph_outs_map[out]]
_y = replace_with_out[idx] _y = replace_with_out[idx]
ls = local_fgraph.outputs ls = local_fgraph.outputs
if out in op.inner_mitsot_outs(ls): if out in op.inner_mitsot_outs(ls):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论