提交 61b39397 作者: Caglar

Added the easy optimizations for the scan.

上级 d7d722fa
......@@ -225,18 +225,33 @@ class PushOutNonSeqScan(gof.Optimizer):
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
max_iterations = 2 * len(local_fgraph.toposort()) + 3
local_fgraph_topo = local_fgraph.toposort()
max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0
to_remove = []
to_replace = []
to_replace_set = set({})
to_replace_map = {}
nto_replace = 0
def add_to_replace(y, nto_replace):
    """Register ``y`` as a variable scheduled for replacement.

    Adds ``y`` to the enclosing ``to_replace_set`` and records
    ``nto_replace`` as its position in the enclosing ``to_replace_map``.
    If ``y`` is already present, its stored position is overwritten.

    :param y: the variable to register (must be hashable).
    :param nto_replace: the position to associate with ``y``.
    :return: ``nto_replace + 1``, which the caller assigns back to its
        running counter so the next registration gets the next position.
    """
    # A single element is inserted, so use set.add rather than wrapping
    # it in a temporary list for set.update.
    to_replace_set.add(y)
    to_replace_map[y] = nto_replace
    return nto_replace + 1
replace_with_in = []
replace_with_out = []
op = node.op
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v:k for k, v in enumerate(inner_non_seqs)}
outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs)
outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
assert len(inner_seqs) == len(outer_seqs)
......@@ -244,8 +259,8 @@ class PushOutNonSeqScan(gof.Optimizer):
counts += 1
changed = False
for nd in local_fgraph.toposort():
if (numpy.all([(x in inner_non_seqs) or
for nd in local_fgraph_topo:
if (all([(x in inner_non_seqs_set) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant)
for x in nd.inputs]) and
......@@ -262,12 +277,12 @@ class PushOutNonSeqScan(gof.Optimizer):
to_remove.append(nd)
outside_ins = []
for x in nd.inputs:
if x in inner_non_seqs:
_idx = inner_non_seqs.index(x)
outside_ins += [outer_non_seqs[_idx]]
elif x in to_replace:
outside_ins += [
replace_with_out[to_replace.index(x)]]
if x in inner_non_seqs_set:
_idx = inner_non_seqs_map[x]
outside_ins.append(outer_non_seqs[_idx])
elif x in to_replace_set:
outside_ins.append(replace_with_out[to_replace_map[x]])
elif isinstance(x, theano.Constant):
outside_ins += [x.clone()]
else:
......@@ -286,12 +301,12 @@ class PushOutNonSeqScan(gof.Optimizer):
# Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
nto_replace = add_to_replace(y, nto_replace)
replace_with_in.append(y_place_holder)
assert type(y) == type(nw_outer_node.outputs[idx])
replace_with_out += [nw_outer_node.outputs[idx]]
replace_with_out.append(nw_outer_node.outputs[idx])
changed = True
if counts >= max_iterations:
raise Exception('Error in the `scan_pushout_non_seq_operations`.'
......@@ -305,20 +320,22 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_to_replace = []
clean_replace_with_in = []
clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph.toposort()
existent_nodes = [nd for nd in local_fgraph_topo
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
for idx, out in enumerate(to_replace):
if (out in to_keep
to_keep_set = set(to_keep)
for out, idx in to_replace_map.iteritems():
if (out in to_keep_set
and out.owner not in existent_nodes
# If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
and replace_with_in[idx].type == out.type):
clean_to_replace += [out]
clean_replace_with_in += [replace_with_in[idx]]
clean_replace_with_out += [replace_with_out[idx]]
clean_to_replace.append(out)
clean_replace_with_in.append(replace_with_in[idx])
clean_replace_with_out.append(replace_with_out[idx])
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
......@@ -354,7 +371,7 @@ class PushOutNonSeqScan(gof.Optimizer):
elif to_keep == []:
# Nothing in the inner graph should be kept
replace_with = OrderedDict()
for idx, out in enumerate(to_replace):
for out, idx in to_replace_map.iteritems():
if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)]
y = replace_with_out[idx]
......@@ -399,18 +416,33 @@ class PushOutSeqScan(gof.Optimizer):
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
max_iterations = 2 * len(local_fgraph.toposort()) + 3
local_fgraph_topo = local_fgraph.toposort()
max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0
to_remove = []
to_replace = []
to_replace_set = set({})
to_replace_map = {}
nto_replace = 0
def add_to_replace(y, nto_replace):
    """Register ``y`` as a variable scheduled for replacement.

    Adds ``y`` to the enclosing ``to_replace_set`` and records
    ``nto_replace`` as its position in the enclosing ``to_replace_map``.
    If ``y`` is already present, its stored position is overwritten.

    :param y: the variable to register (must be hashable).
    :param nto_replace: the position to associate with ``y``.
    :return: ``nto_replace + 1``, which the caller assigns back to its
        running counter so the next registration gets the next position.
    """
    # A single element is inserted, so use set.add rather than wrapping
    # it in a temporary list for set.update.
    to_replace_set.add(y)
    to_replace_map[y] = nto_replace
    return nto_replace + 1
replace_with_in = []
replace_with_out = []
op = node.op
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v:k for k, v in enumerate(inner_non_seqs)}
outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs)
inner_seqs_map = {v:k for k, v in enumerate(inner_seqs)}
outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
assert len(inner_seqs) == len(outer_seqs)
......@@ -419,12 +451,12 @@ class PushOutSeqScan(gof.Optimizer):
counts += 1
changed = False
for nd in local_fgraph.toposort():
for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Elemwise) and
numpy.all([(x in inner_non_seqs) or
all([(x in inner_non_seqs_set) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant) or
(x in inner_seqs)
(x in inner_seqs_set)
for x in nd.inputs]) and
not nd in to_remove):
to_remove.append(nd)
......@@ -433,17 +465,17 @@ class PushOutSeqScan(gof.Optimizer):
for x in nd.inputs:
if x in inner_non_seqs:
_idx = inner_non_seqs.index(x)
outside_ins += [outer_non_seqs[_idx]]
elif x in inner_seqs:
outside_ins += [outer_seqs[inner_seqs.index(x)]]
_idx = inner_non_seqs_map[x]
outside_ins.append(outer_non_seqs[_idx])
elif x in inner_seqs_set:
outside_ins.append(outer_seqs[inner_seqs_map[x]])
depends_on_seqs = True
elif x in to_replace:
outside_ins += [replace_with_out[
to_replace.index(x)]]
elif x in to_replace_set:
outside_ins.append(replace_with_out[
to_replace_map[x]])
depends_on_seqs = True
elif isinstance(x, theano.Constant):
outside_ins += [x.clone()]
outside_ins.append(x.clone())
else:
raise Exception(
('Error in the `scan_pushout_seq_'
......@@ -466,24 +498,23 @@ class PushOutSeqScan(gof.Optimizer):
# Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
replace_with_out += [nw_outer_node.outputs[idx]]
nto_replace = add_to_replace(y, nto_replace)
replace_with_in.append(y_place_holder)
replace_with_out.append(nw_outer_node.outputs[idx])
changed = True
elif (isinstance(nd.op, theano.tensor.DimShuffle) and
(nd.inputs[0] in inner_seqs or
(nd.inputs[0] in inner_seqs_set or
nd.inputs[0].owner in to_remove) and
not nd in to_remove):
to_remove.append(nd)
x = nd.inputs[0]
if x in inner_seqs:
outside_ins = outer_seqs[inner_seqs.index(x)]
elif x in to_replace:
outside_ins = replace_with_out[to_replace.index(x)]
if x in inner_seqs_set:
outside_ins = outer_seqs[inner_seqs_map[x]]
elif x in to_replace_set:
outside_ins = replace_with_out[to_replace_map[x]]
new_ord = (0,)
for old_ord in nd.op.new_order:
if (old_ord == 'x'):
......@@ -493,9 +524,10 @@ class PushOutSeqScan(gof.Optimizer):
new_outer = outside_ins.dimshuffle(new_ord)
y = nd.outputs[0]
y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
replace_with_out += [new_outer]
nto_replace = add_to_replace(y, nto_replace)
replace_with_in.append(y_place_holder)
replace_with_out.append(new_outer)
if hasattr(new_outer.tag, "test_value"):
new_sh = new_outer.tag.test_value.shape
ref_sh = (outside_ins.tag.test_value.shape[0],)
......@@ -516,20 +548,23 @@ class PushOutSeqScan(gof.Optimizer):
clean_replace_with_in = []
clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph.toposort()
existent_nodes = [nd for nd in local_fgraph_topo
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
for idx, out in enumerate(to_replace):
if (out in to_keep
to_keep.extend(nd.inputs)
to_keep_set = set(to_keep)
for out, idx in to_replace_map.iteritems():
if (out in to_keep_set
and out.owner not in existent_nodes
# If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop.
and replace_with_in[idx].type == out.type):
clean_to_replace += [out]
clean_replace_with_in += [replace_with_in[idx]]
clean_replace_with_out += [replace_with_out[idx]]
clean_to_replace.append(out)
clean_replace_with_in.append(replace_with_in[idx])
clean_replace_with_out.append(replace_with_out[idx])
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
......@@ -542,8 +577,9 @@ class PushOutSeqScan(gof.Optimizer):
if isinstance(repl_out, theano.Constant):
repl_in = repl_out.clone()
else:
nw_inner += [repl_in]
nw_outer += [repl_out]
nw_inner.append(repl_in)
nw_outer.append(repl_out)
givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs,
......@@ -568,7 +604,7 @@ class PushOutSeqScan(gof.Optimizer):
not op.outer_mitmot(node)):
# Nothing in the inner graph should be kept
replace_with = OrderedDict()
for idx, out in enumerate(to_replace):
for out, idx in to_replace_map.iteritems():
if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)]
_y = replace_with_out[idx]
......@@ -631,15 +667,15 @@ class PushOutScanOutput(gof.Optimizer):
# Use scan_args to parse the inputs and outputs of scan for ease of
# use
args = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info)
op.inputs, op.outputs, op.info)
local_fgraph = gof.FunctionGraph(args.inner_inputs,
args.inner_outputs,
clone=False)
new_scan_node = None
for nd in local_fgraph.toposort():
local_fgraph_topo = local_fgraph.toposort()
for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Dot) and
nd.out in args.inner_out_nit_sot):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论