提交 e456487d authored 作者: Caglar's avatar Caglar

various refactoring and fixes.

上级 61b39397
...@@ -61,6 +61,7 @@ import logging ...@@ -61,6 +61,7 @@ import logging
import copy import copy
from sys import maxsize from sys import maxsize
import numpy import numpy
from itertools import chain
import theano import theano
from theano import tensor from theano import tensor
...@@ -162,8 +163,8 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -162,8 +163,8 @@ def remove_constants_and_unused_inputs_scan(node):
index = node.inputs.index(identical_seqs[0]) - 1 index = node.inputs.index(identical_seqs[0]) - 1
givens[op_ins[idx]] = op_ins[index] givens[op_ins[idx]] = op_ins[index]
else: else:
nw_inner += [op_ins[idx]] nw_inner.append(op_ins[idx])
nw_outer += [node_inp] nw_outer.append(node_inp)
nw_n_seqs = len(nw_inner) nw_n_seqs = len(nw_inner)
# Add outputs stuff # Add outputs stuff
...@@ -185,8 +186,9 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -185,8 +186,9 @@ def remove_constants_and_unused_inputs_scan(node):
if identical_nonseq_idx: if identical_nonseq_idx:
givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]] givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]]
else: else:
nw_inner_nonseq += [nw_in] nw_inner_nonseq.append(nw_in)
nw_outer_nonseq += [nw_out] nw_outer_nonseq.append(nw_out)
nw_inner.extend(nw_inner_nonseq) nw_inner.extend(nw_inner_nonseq)
nw_outer.extend(nw_outer_nonseq) nw_outer.extend(nw_outer_nonseq)
...@@ -229,8 +231,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -229,8 +231,7 @@ class PushOutNonSeqScan(gof.Optimizer):
max_iterations = 2 * len(local_fgraph_topo) + 3 max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0 counts = 0
to_remove = [] to_remove_set = set({})
to_replace_set = set({}) to_replace_set = set({})
to_replace_map = {} to_replace_map = {}
nto_replace = 0 nto_replace = 0
...@@ -245,7 +246,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -245,7 +246,7 @@ class PushOutNonSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v:k for k, v in enumerate(inner_non_seqs)} inner_non_seqs_map = dict({v:k for k,v in enumerate(inner_non_seqs)})
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
...@@ -261,7 +262,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -261,7 +262,7 @@ class PushOutNonSeqScan(gof.Optimizer):
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (all([(x in inner_non_seqs_set) or if (all([(x in inner_non_seqs_set) or
(x.owner in to_remove) or (x.owner in to_remove_set) or
isinstance(x, tensor.Constant) isinstance(x, tensor.Constant)
for x in nd.inputs]) and for x in nd.inputs]) and
# we can do this because the assumption is that a # we can do this because the assumption is that a
...@@ -270,11 +271,11 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -270,11 +271,11 @@ class PushOutNonSeqScan(gof.Optimizer):
not isinstance(nd.op, theano.compile.ViewOp) and not isinstance(nd.op, theano.compile.ViewOp) and
not isinstance(nd.op, theano.compile.DeepCopyOp) and not isinstance(nd.op, theano.compile.DeepCopyOp) and
# and we didn't already looked at this node # and we didn't already looked at this node
not nd in to_remove): not nd in to_remove_set):
# We have a candidate node to removable # We have a candidate node to removable
# Step 1. Reconstruct it on outside # Step 1. Reconstruct it on outside
to_remove.append(nd) to_remove_set.update([nd])
outside_ins = [] outside_ins = []
for x in nd.inputs: for x in nd.inputs:
if x in inner_non_seqs_set: if x in inner_non_seqs_set:
...@@ -321,10 +322,9 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -321,10 +322,9 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_replace_with_in = [] clean_replace_with_in = []
clean_replace_with_out = [] clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph_topo existent_nodes = [nd for nd in local_fgraph_topo
if nd not in to_remove] if nd not in to_remove_set]
to_keep = []
for nd in existent_nodes: to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes]
to_keep += nd.inputs
to_keep_set = set(to_keep) to_keep_set = set(to_keep)
for out, idx in to_replace_map.iteritems(): for out, idx in to_replace_map.iteritems():
...@@ -368,14 +368,14 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -368,14 +368,14 @@ class PushOutNonSeqScan(gof.Optimizer):
remove=[node], remove=[node],
reason='scanOp_pushout_nonseqs_ops') reason='scanOp_pushout_nonseqs_ops')
return True return True
elif to_keep == []: elif not to_keep:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = OrderedDict() replace_with = OrderedDict()
for out, idx in to_replace_map.iteritems(): for out, idx in to_replace_map.iteritems():
if out in local_fgraph.outputs: if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)] x = node.outputs[local_fgraph.outputs.index(out)]
y = replace_with_out[idx] y = replace_with_out[idx]
shape = [y.shape[idx] for idx in xrange(y.ndim)] shape = [shp for shp in y.shape]
replace_with[x] = tensor.alloc(y, replace_with[x] = tensor.alloc(y,
node.inputs[0], node.inputs[0],
*shape) *shape)
...@@ -419,8 +419,8 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -419,8 +419,8 @@ class PushOutSeqScan(gof.Optimizer):
local_fgraph_topo = local_fgraph.toposort() local_fgraph_topo = local_fgraph.toposort()
max_iterations = 2 * len(local_fgraph_topo) + 3 max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0 counts = 0
to_remove = []
to_remove_set = set({})
to_replace_set = set({}) to_replace_set = set({})
to_replace_map = {} to_replace_map = {}
nto_replace = 0 nto_replace = 0
...@@ -436,12 +436,12 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -436,12 +436,12 @@ class PushOutSeqScan(gof.Optimizer):
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs) inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v:k for k, v in enumerate(inner_non_seqs)} inner_non_seqs_map = dict({v:k for k,v in enumerate(inner_non_seqs)})
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs) inner_seqs_set = set(inner_seqs)
inner_seqs_map = {v:k for k, v in enumerate(inner_seqs)} inner_seqs_map = dict({v:k for k, v in enumerate(inner_seqs)})
outer_seqs = op.outer_seqs(node.inputs) outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs) assert len(inner_non_seqs) == len(outer_non_seqs)
...@@ -454,12 +454,12 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -454,12 +454,12 @@ class PushOutSeqScan(gof.Optimizer):
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Elemwise) and if (isinstance(nd.op, theano.tensor.Elemwise) and
all([(x in inner_non_seqs_set) or all([(x in inner_non_seqs_set) or
(x.owner in to_remove) or (x.owner in to_remove_set) or
isinstance(x, tensor.Constant) or isinstance(x, tensor.Constant) or
(x in inner_seqs_set) (x in inner_seqs_set)
for x in nd.inputs]) and for x in nd.inputs]) and
not nd in to_remove): not nd in to_remove_set):
to_remove.append(nd) to_remove_set.update([nd])
outside_ins = [] outside_ins = []
depends_on_seqs = False depends_on_seqs = False
...@@ -507,9 +507,9 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -507,9 +507,9 @@ class PushOutSeqScan(gof.Optimizer):
elif (isinstance(nd.op, theano.tensor.DimShuffle) and elif (isinstance(nd.op, theano.tensor.DimShuffle) and
(nd.inputs[0] in inner_seqs_set or (nd.inputs[0] in inner_seqs_set or
nd.inputs[0].owner in to_remove) and nd.inputs[0].owner in to_remove_set) and
not nd in to_remove): not nd in to_remove_set):
to_remove.append(nd) to_remove_set.update([nd])
x = nd.inputs[0] x = nd.inputs[0]
if x in inner_seqs_set: if x in inner_seqs_set:
outside_ins = outer_seqs[inner_seqs_map[x]] outside_ins = outer_seqs[inner_seqs_map[x]]
...@@ -549,10 +549,8 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -549,10 +549,8 @@ class PushOutSeqScan(gof.Optimizer):
clean_replace_with_out = [] clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph_topo existent_nodes = [nd for nd in local_fgraph_topo
if nd not in to_remove] if nd not in to_remove_set]
to_keep = [] to_keep = []; [to_keep.extend(nd.inputs) for nd in existent_nodes]
for nd in existent_nodes:
to_keep.extend(nd.inputs)
to_keep_set = set(to_keep) to_keep_set = set(to_keep)
for out, idx in to_replace_map.iteritems(): for out, idx in to_replace_map.iteritems():
...@@ -599,7 +597,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -599,7 +597,7 @@ class PushOutSeqScan(gof.Optimizer):
remove=[node], remove=[node],
reason='scanOp_pushout_seqs_ops') reason='scanOp_pushout_seqs_ops')
return True return True
elif (to_keep == [] and elif (not to_keep and
not op.as_while and not op.as_while and
not op.outer_mitmot(node)): not op.outer_mitmot(node)):
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论