提交 61b39397 authored 作者: Caglar's avatar Caglar

Added the easy optimizations for the scan.

上级 d7d722fa
...@@ -225,18 +225,33 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -225,18 +225,33 @@ class PushOutNonSeqScan(gof.Optimizer):
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False) local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
max_iterations = 2 * len(local_fgraph.toposort()) + 3 local_fgraph_topo = local_fgraph.toposort()
max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0 counts = 0
to_remove = [] to_remove = []
to_replace = []
to_replace_set = set({})
to_replace_map = {}
nto_replace = 0
def add_to_replace(y, nto_replace):
to_replace_set.update([y])
to_replace_map[y] = nto_replace
return nto_replace + 1
replace_with_in = [] replace_with_in = []
replace_with_out = [] replace_with_out = []
op = node.op op = node.op
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v:k for k, v in enumerate(inner_non_seqs)}
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(clean_inputs)
outer_seqs = op.outer_seqs(node.inputs) outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs) assert len(inner_non_seqs) == len(outer_non_seqs)
assert len(inner_seqs) == len(outer_seqs) assert len(inner_seqs) == len(outer_seqs)
...@@ -244,8 +259,8 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -244,8 +259,8 @@ class PushOutNonSeqScan(gof.Optimizer):
counts += 1 counts += 1
changed = False changed = False
for nd in local_fgraph.toposort(): for nd in local_fgraph_topo:
if (numpy.all([(x in inner_non_seqs) or if (all([(x in inner_non_seqs_set) or
(x.owner in to_remove) or (x.owner in to_remove) or
isinstance(x, tensor.Constant) isinstance(x, tensor.Constant)
for x in nd.inputs]) and for x in nd.inputs]) and
...@@ -262,12 +277,12 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -262,12 +277,12 @@ class PushOutNonSeqScan(gof.Optimizer):
to_remove.append(nd) to_remove.append(nd)
outside_ins = [] outside_ins = []
for x in nd.inputs: for x in nd.inputs:
if x in inner_non_seqs: if x in inner_non_seqs_set:
_idx = inner_non_seqs.index(x) _idx = inner_non_seqs_map[x]
outside_ins += [outer_non_seqs[_idx]] outside_ins.append(outer_non_seqs[_idx])
elif x in to_replace: elif x in to_replace_set:
outside_ins += [ outside_ins.append(replace_with_out[to_replace_map[x]])
replace_with_out[to_replace.index(x)]]
elif isinstance(x, theano.Constant): elif isinstance(x, theano.Constant):
outside_ins += [x.clone()] outside_ins += [x.clone()]
else: else:
...@@ -286,12 +301,12 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -286,12 +301,12 @@ class PushOutNonSeqScan(gof.Optimizer):
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs): for idx, y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y, '_replace') y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y] nto_replace = add_to_replace(y, nto_replace)
replace_with_in += [y_place_holder] replace_with_in.append(y_place_holder)
assert type(y) == type(nw_outer_node.outputs[idx]) assert type(y) == type(nw_outer_node.outputs[idx])
replace_with_out += [nw_outer_node.outputs[idx]] replace_with_out.append(nw_outer_node.outputs[idx])
changed = True changed = True
if counts >= max_iterations: if counts >= max_iterations:
raise Exception('Error in the `scan_pushout_non_seq_operations`.' raise Exception('Error in the `scan_pushout_non_seq_operations`.'
...@@ -305,20 +320,22 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -305,20 +320,22 @@ class PushOutNonSeqScan(gof.Optimizer):
clean_to_replace = [] clean_to_replace = []
clean_replace_with_in = [] clean_replace_with_in = []
clean_replace_with_out = [] clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph.toposort() existent_nodes = [nd for nd in local_fgraph_topo
if nd not in to_remove] if nd not in to_remove]
to_keep = [] to_keep = []
for nd in existent_nodes: for nd in existent_nodes:
to_keep += nd.inputs to_keep += nd.inputs
for idx, out in enumerate(to_replace): to_keep_set = set(to_keep)
if (out in to_keep
for out, idx in to_replace_map.iteritems():
if (out in to_keep_set
and out.owner not in existent_nodes and out.owner not in existent_nodes
# If types are different, conversion Op will be inserted, # If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop. # and it may trigger an infinite loop.
and replace_with_in[idx].type == out.type): and replace_with_in[idx].type == out.type):
clean_to_replace += [out] clean_to_replace.append(out)
clean_replace_with_in += [replace_with_in[idx]] clean_replace_with_in.append(replace_with_in[idx])
clean_replace_with_out += [replace_with_out[idx]] clean_replace_with_out.append(replace_with_out[idx])
if len(clean_to_replace) > 0: if len(clean_to_replace) > 0:
# We can finally put an end to all this madness # We can finally put an end to all this madness
...@@ -354,7 +371,7 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -354,7 +371,7 @@ class PushOutNonSeqScan(gof.Optimizer):
elif to_keep == []: elif to_keep == []:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = OrderedDict() replace_with = OrderedDict()
for idx, out in enumerate(to_replace): for out, idx in to_replace_map.iteritems():
if out in local_fgraph.outputs: if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)] x = node.outputs[local_fgraph.outputs.index(out)]
y = replace_with_out[idx] y = replace_with_out[idx]
...@@ -399,18 +416,33 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -399,18 +416,33 @@ class PushOutSeqScan(gof.Optimizer):
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False) local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
max_iterations = 2 * len(local_fgraph.toposort()) + 3 local_fgraph_topo = local_fgraph.toposort()
max_iterations = 2 * len(local_fgraph_topo) + 3
counts = 0 counts = 0
to_remove = [] to_remove = []
to_replace = []
to_replace_set = set({})
to_replace_map = {}
nto_replace = 0
def add_to_replace(y, nto_replace):
to_replace_set.update([y])
to_replace_map[y] = nto_replace
return nto_replace + 1
replace_with_in = [] replace_with_in = []
replace_with_out = [] replace_with_out = []
op = node.op op = node.op
# Construct the list of non_sequences to simplify a few things # Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs) inner_non_seqs = op.inner_non_seqs(clean_inputs)
inner_non_seqs_set = set(inner_non_seqs)
inner_non_seqs_map = {v:k for k, v in enumerate(inner_non_seqs)}
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs) inner_seqs = op.inner_seqs(clean_inputs)
inner_seqs_set = set(inner_seqs)
inner_seqs_map = {v:k for k, v in enumerate(inner_seqs)}
outer_seqs = op.outer_seqs(node.inputs) outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs) assert len(inner_non_seqs) == len(outer_non_seqs)
assert len(inner_seqs) == len(outer_seqs) assert len(inner_seqs) == len(outer_seqs)
...@@ -419,12 +451,12 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -419,12 +451,12 @@ class PushOutSeqScan(gof.Optimizer):
counts += 1 counts += 1
changed = False changed = False
for nd in local_fgraph.toposort(): for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Elemwise) and if (isinstance(nd.op, theano.tensor.Elemwise) and
numpy.all([(x in inner_non_seqs) or all([(x in inner_non_seqs_set) or
(x.owner in to_remove) or (x.owner in to_remove) or
isinstance(x, tensor.Constant) or isinstance(x, tensor.Constant) or
(x in inner_seqs) (x in inner_seqs_set)
for x in nd.inputs]) and for x in nd.inputs]) and
not nd in to_remove): not nd in to_remove):
to_remove.append(nd) to_remove.append(nd)
...@@ -433,17 +465,17 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -433,17 +465,17 @@ class PushOutSeqScan(gof.Optimizer):
for x in nd.inputs: for x in nd.inputs:
if x in inner_non_seqs: if x in inner_non_seqs:
_idx = inner_non_seqs.index(x) _idx = inner_non_seqs_map[x]
outside_ins += [outer_non_seqs[_idx]] outside_ins.append(outer_non_seqs[_idx])
elif x in inner_seqs: elif x in inner_seqs_set:
outside_ins += [outer_seqs[inner_seqs.index(x)]] outside_ins.append(outer_seqs[inner_seqs_map[x]])
depends_on_seqs = True depends_on_seqs = True
elif x in to_replace: elif x in to_replace_set:
outside_ins += [replace_with_out[ outside_ins.append(replace_with_out[
to_replace.index(x)]] to_replace_map[x]])
depends_on_seqs = True depends_on_seqs = True
elif isinstance(x, theano.Constant): elif isinstance(x, theano.Constant):
outside_ins += [x.clone()] outside_ins.append(x.clone())
else: else:
raise Exception( raise Exception(
('Error in the `scan_pushout_seq_' ('Error in the `scan_pushout_seq_'
...@@ -466,24 +498,23 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -466,24 +498,23 @@ class PushOutSeqScan(gof.Optimizer):
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs): for idx, y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y, '_replace') y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y] nto_replace = add_to_replace(y, nto_replace)
replace_with_in += [y_place_holder] replace_with_in.append(y_place_holder)
replace_with_out += [nw_outer_node.outputs[idx]] replace_with_out.append(nw_outer_node.outputs[idx])
changed = True changed = True
elif (isinstance(nd.op, theano.tensor.DimShuffle) and elif (isinstance(nd.op, theano.tensor.DimShuffle) and
(nd.inputs[0] in inner_seqs or (nd.inputs[0] in inner_seqs_set or
nd.inputs[0].owner in to_remove) and nd.inputs[0].owner in to_remove) and
not nd in to_remove): not nd in to_remove):
to_remove.append(nd) to_remove.append(nd)
x = nd.inputs[0] x = nd.inputs[0]
if x in inner_seqs: if x in inner_seqs_set:
outside_ins = outer_seqs[inner_seqs.index(x)] outside_ins = outer_seqs[inner_seqs_map[x]]
elif x in to_replace: elif x in to_replace_set:
outside_ins = replace_with_out[to_replace.index(x)] outside_ins = replace_with_out[to_replace_map[x]]
new_ord = (0,) new_ord = (0,)
for old_ord in nd.op.new_order: for old_ord in nd.op.new_order:
if (old_ord == 'x'): if (old_ord == 'x'):
...@@ -493,9 +524,10 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -493,9 +524,10 @@ class PushOutSeqScan(gof.Optimizer):
new_outer = outside_ins.dimshuffle(new_ord) new_outer = outside_ins.dimshuffle(new_ord)
y = nd.outputs[0] y = nd.outputs[0]
y_place_holder = scan_utils.safe_new(y, '_replace') y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y] nto_replace = add_to_replace(y, nto_replace)
replace_with_in += [y_place_holder] replace_with_in.append(y_place_holder)
replace_with_out += [new_outer] replace_with_out.append(new_outer)
if hasattr(new_outer.tag, "test_value"): if hasattr(new_outer.tag, "test_value"):
new_sh = new_outer.tag.test_value.shape new_sh = new_outer.tag.test_value.shape
ref_sh = (outside_ins.tag.test_value.shape[0],) ref_sh = (outside_ins.tag.test_value.shape[0],)
...@@ -516,20 +548,23 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -516,20 +548,23 @@ class PushOutSeqScan(gof.Optimizer):
clean_replace_with_in = [] clean_replace_with_in = []
clean_replace_with_out = [] clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph.toposort() existent_nodes = [nd for nd in local_fgraph_topo
if nd not in to_remove] if nd not in to_remove]
to_keep = [] to_keep = []
for nd in existent_nodes: for nd in existent_nodes:
to_keep += nd.inputs to_keep.extend(nd.inputs)
for idx, out in enumerate(to_replace): to_keep_set = set(to_keep)
if (out in to_keep
for out, idx in to_replace_map.iteritems():
if (out in to_keep_set
and out.owner not in existent_nodes and out.owner not in existent_nodes
# If types are different, conversion Op will be inserted, # If types are different, conversion Op will be inserted,
# and it may trigger an infinite loop. # and it may trigger an infinite loop.
and replace_with_in[idx].type == out.type): and replace_with_in[idx].type == out.type):
clean_to_replace += [out]
clean_replace_with_in += [replace_with_in[idx]] clean_to_replace.append(out)
clean_replace_with_out += [replace_with_out[idx]] clean_replace_with_in.append(replace_with_in[idx])
clean_replace_with_out.append(replace_with_out[idx])
if len(clean_to_replace) > 0: if len(clean_to_replace) > 0:
# We can finally put an end to all this madness # We can finally put an end to all this madness
...@@ -542,8 +577,9 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -542,8 +577,9 @@ class PushOutSeqScan(gof.Optimizer):
if isinstance(repl_out, theano.Constant): if isinstance(repl_out, theano.Constant):
repl_in = repl_out.clone() repl_in = repl_out.clone()
else: else:
nw_inner += [repl_in] nw_inner.append(repl_in)
nw_outer += [repl_out] nw_outer.append(repl_out)
givens[to_repl] = repl_in givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs, _op_outs = scan_utils.clone(clean_outputs,
...@@ -568,7 +604,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -568,7 +604,7 @@ class PushOutSeqScan(gof.Optimizer):
not op.outer_mitmot(node)): not op.outer_mitmot(node)):
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = OrderedDict() replace_with = OrderedDict()
for idx, out in enumerate(to_replace): for out, idx in to_replace_map.iteritems():
if out in local_fgraph.outputs: if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)] x = node.outputs[local_fgraph.outputs.index(out)]
_y = replace_with_out[idx] _y = replace_with_out[idx]
...@@ -631,15 +667,15 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -631,15 +667,15 @@ class PushOutScanOutput(gof.Optimizer):
# Use scan_args to parse the inputs and outputs of scan for ease of # Use scan_args to parse the inputs and outputs of scan for ease of
# use # use
args = scan_args(node.inputs, node.outputs, args = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info) op.inputs, op.outputs, op.info)
local_fgraph = gof.FunctionGraph(args.inner_inputs, local_fgraph = gof.FunctionGraph(args.inner_inputs,
args.inner_outputs, args.inner_outputs,
clone=False) clone=False)
new_scan_node = None new_scan_node = None
local_fgraph_topo = local_fgraph.toposort()
for nd in local_fgraph.toposort(): for nd in local_fgraph_topo:
if (isinstance(nd.op, theano.tensor.Dot) and if (isinstance(nd.op, theano.tensor.Dot) and
nd.out in args.inner_out_nit_sot): nd.out in args.inner_out_nit_sot):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论