提交 e456487d authored 作者: Caglar's avatar Caglar

various refactoring and fixes.

上级 61b39397
......@@ -61,6 +61,7 @@ import logging
import copy
from sys import maxsize
import numpy
from itertools import chain
import theano
from theano import tensor
......@@ -162,8 +163,8 @@ def remove_constants_and_unused_inputs_scan(node):
index = node.inputs.index(identical_seqs[0]) - 1
givens[op_ins[idx]] = op_ins[index]
else:
nw_inner += [op_ins[idx]]
nw_outer += [node_inp]
nw_inner.append(op_ins[idx])
nw_outer.append(node_inp)
nw_n_seqs = len(nw_inner)
# Add outputs stuff
......@@ -185,8 +186,9 @@ def remove_constants_and_unused_inputs_scan(node):
if identical_nonseq_idx:
givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]]
else:
nw_inner_nonseq += [nw_in]
nw_outer_nonseq += [nw_out]
nw_inner_nonseq.append(nw_in)
nw_outer_nonseq.append(nw_out)
nw_inner.extend(nw_inner_nonseq)
nw_outer.extend(nw_outer_nonseq)
......@@ -229,8 +231,7 @@ class PushOutNonSeqScan(gof.Optimizer):
max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0
to_remove = []
to_remove_set = set({})
to_replace_set = set({})
to_replace_map = {}
nto_replace = 0
......@@ -245,7 +246,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v:k for k, v in enumerate(inner_non_seqs)}
inner_non_seqs_map = dict({v:k for k,v in enumerate(inner_non_seqs)})
outer_non_seqs = op.outer_non_seqs(node.inputs)
......@@ -261,7 +262,7 @@ class PushOutNonSeqScan(gof.Optimizer):
for nd in local_fgraph_topo:
if (all([(x in inner_non_seqs_set) or
(x.owner in to_remove) or
(x.owner in to_remove_set) or
isinstance(x, tensor.Constant)
for x in nd.inputs]) and
# we can do this because the assumption is that a
......@@ -270,11 +271,11 @@ class PushOutNonSeqScan(gof.Optimizer):
not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp) and
# and we didn't already looked at this node
not nd in to_remove):
not nd in to_remove_set):
# We have a candidate node to removable
# Step 1. Reconstruct it on outside
to_remove.append(nd)
to_remove_set.update([nd])
outside_ins = []
for x in nd.inputs:
if x in inner_non_seqs_set:
......@@ -321,10 +322,9 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_replace_with_in = []
clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph_topo
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
if nd not in to_remove_set]
to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes]
to_keep_set = set(to_keep)
for out, idx in to_replace_map.iteritems():
......@@ -368,14 +368,14 @@ class PushOutNonSeqScan(gof.Optimizer):
remove=[node],
reason='scanOp_pushout_nonseqs_ops')
return True
elif to_keep == []:
elif not to_keep:
# Nothing in the inner graph should be kept
replace_with = OrderedDict()
for out, idx in to_replace_map.iteritems():
if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)]
y = replace_with_out[idx]
shape = [y.shape[idx] for idx in xrange(y.ndim)]
shape = [shp for shp in y.shape]
replace_with[x] = tensor.alloc(y,
node.inputs[0],
*shape)
......@@ -419,8 +419,8 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph_topo = local_fgraph.toposort()
max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0
to_remove = []
to_remove_set = set({})
to_replace_set = set({})
to_replace_map = {}
nto_replace = 0
......@@ -436,12 +436,12 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v:k for k, v in enumerate(inner_non_seqs)}
inner_non_seqs_map = dict({v:k for k,v in enumerate(inner_non_seqs)})
outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs)
inner_seqs_map = {v:k for k, v in enumerate(inner_seqs)}
inner_seqs_map = dict({v:k for k, v in enumerate(inner_seqs)})
outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
......@@ -454,12 +454,12 @@ class PushOutSeqScan(gof.Optimizer):
for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Elemwise) and
all([(x in inner_non_seqs_set) or
(x.owner in to_remove) or
(x.owner in to_remove_set) or
isinstance(x, tensor.Constant) or
(x in inner_seqs_set)
for x in nd.inputs]) and
not nd in to_remove):
to_remove.append(nd)
not nd in to_remove_set):
to_remove_set.update([nd])
outside_ins = []
depends_on_seqs = False
......@@ -507,9 +507,9 @@ class PushOutSeqScan(gof.Optimizer):
elif (isinstance(nd.op, theano.tensor.DimShuffle) and
(nd.inputs[0] in inner_seqs_set or
nd.inputs[0].owner in to_remove) and
not nd in to_remove):
to_remove.append(nd)
nd.inputs[0].owner in to_remove_set) and
not nd in to_remove_set):
to_remove_set.update([nd])
x = nd.inputs[0]
if x in inner_seqs_set:
outside_ins = outer_seqs[inner_seqs_map[x]]
......@@ -549,10 +549,8 @@ class PushOutSeqScan(gof.Optimizer):
clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph_topo
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep.extend(nd.inputs)
if nd not in to_remove_set]
to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes]
to_keep_set = set(to_keep)
for out, idx in to_replace_map.iteritems():
......@@ -599,7 +597,7 @@ class PushOutSeqScan(gof.Optimizer):
remove=[node],
reason='scanOp_pushout_seqs_ops')
return True
elif (to_keep == [] and
elif (not to_keep and
not op.as_while and
not op.outer_mitmot(node)):
# Nothing in the inner graph should be kept
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论