提交 f3f04482 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Initial version of optimization

上级 7fb90052
......@@ -583,6 +583,221 @@ class PushOutSeqScan(gof.Optimizer):
return False
class PushOutScanOutput(gof.Optimizer):
    """
    Graph optimizer that moves operations performed at the end of a Scan's
    inner graph to the outer graph.

    The only pattern handled so far is a ``Dot`` whose result is an output
    of the inner graph, where one operand is a matrix that does not change
    between iterations (a non-sequence or a constant) and the other is a
    vector computed inside the inner graph. The vector is exposed as an
    extra nit_sot output of the Scan and a single ``Dot`` is applied, outside
    of the Scan, on the matrix formed by all the per-step vectors.
    """

    def __init__(self):
        gof.Optimizer.__init__(self)

    def add_requirements(self, fgraph):
        """Attach the ``ReplaceValidate`` feature needed by ``replace_all``."""
        fgraph.attach_feature(gof.toolbox.ReplaceValidate())

    def apply(self, fgraph):
        """
        Run the optimization on every Scan node present in ``fgraph``.

        Each Scan node is processed repeatedly: ``process_node`` returns the
        Scan node that replaced the one it was given when it managed to
        optimize something, and ``None`` when nothing was changed.
        """
        nodelist = [x for x in fgraph.toposort()
                    if isinstance(x.op, scan_op.Scan)]
        for node in nodelist:
            # Process the node as long as something gets optimized
            while node is not None:
                node = self.process_node(fgraph, node)

    def process_node(self, fgraph, node):
        """
        Try to push one ``Dot`` out of the Scan applied by ``node``.

        :param fgraph: the outer function graph containing ``node``.
        :param node: Apply node of a ``scan_op.Scan``.
        :return: the new Scan Apply node that replaced ``node`` if a ``Dot``
            was pushed out, ``None`` if nothing was changed.
        :raises Exception: if the matrix operand of a ``Dot`` deemed valid is
            neither an inner non-sequence nor a constant (should not happen).
        """
        clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
            node.op.inputs, node.op.outputs)

        local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs,
                                         clone=False)
        op = node.op

        # Construct the list of non_sequences to simplify a few things
        inner_non_seqs = op.inner_non_seqs(clean_inputs)
        outer_non_seqs = op.outer_non_seqs(node.inputs)
        inner_seqs = op.inner_seqs(clean_inputs)
        outer_seqs = op.outer_seqs(node.inputs)
        assert len(inner_non_seqs) == len(outer_non_seqs)
        assert len(inner_seqs) == len(outer_seqs)

        for nd in local_fgraph.toposort():
            if not (isinstance(nd.op, theano.tensor.Dot) and
                    nd.out in clean_outputs):
                continue

            # Ensure that the output of the Dot is used somewhere in the
            # outer graph; otherwise the optimization is not worth applying.
            # NOTE(review): the position of the Dot in the *inner* outputs is
            # used to index the *outer* outputs of the node -- confirm that
            # both lists line up for every kind of Scan.
            idx_dot_output = clean_outputs.index(nd.out)
            if not node.outputs[idx_dot_output].clients:
                continue

            idx_matrix_input = self._matrix_input_index(
                nd, inner_non_seqs, clean_inputs, clean_outputs)
            if idx_matrix_input is None:
                continue
            idx_vector_input = 1 - idx_matrix_input

            # Get, in the outer graph, the equivalent of the Dot's matrix
            # input.
            inner_matrix_input = nd.inputs[idx_matrix_input]
            if inner_matrix_input in inner_non_seqs:
                _idx = inner_non_seqs.index(inner_matrix_input)
                outer_matrix_input = outer_non_seqs[_idx]
            elif isinstance(inner_matrix_input, theano.Constant):
                outer_matrix_input = inner_matrix_input.clone()
            else:
                # Should not have happened
                raise Exception(
                    ('Error in the `scanOp_pushout_output` '
                     'optimization. The optimization tries '
                     'to move some computation from scan '
                     'which is not allowed to move. Report '
                     'this on theano-users list'),
                    inner_matrix_input)

            # Expose the Dot's vector input as a new output of the scan and
            # get the corresponding variable of the outer graph.
            new_output_inner = nd.inputs[idx_vector_input]
            new_scan_node, idx_old_outputs, idx_new_output = \
                self.add_nitsot_outputs(fgraph, node, clean_inputs,
                                        clean_outputs, new_output_inner)
            new_outer_output = new_scan_node.outputs[idx_new_output]

            # Perform the Dot on the new scan output. Each row of
            # `new_outer_output` is the vector of one iteration, so the
            # stacked vectors always go on the left:
            #   inner dot(M, v) -> outer dot(V, M.T)
            #   inner dot(v, M) -> outer dot(V, M)
            if idx_matrix_input == 0:
                outer_dot_output = theano.tensor.dot(new_outer_output,
                                                     outer_matrix_input.T)
            else:  # idx_matrix_input == 1
                outer_dot_output = theano.tensor.dot(new_outer_output,
                                                     outer_matrix_input)

            # Modify the outer graph to use the outer Dot instead of the
            # scan output that used to hold the result of the inner Dot.
            new_idx_dot_output = idx_old_outputs[idx_dot_output]
            fgraph.replace_all(
                [(new_scan_node.outputs[new_idx_dot_output],
                  outer_dot_output)],
                reason="scanOp_pushout_output")

            # Only one Dot is handled per call; the caller calls again with
            # the new node to look for further opportunities.
            return new_scan_node

        return None

    def _matrix_input_index(self, dot_node, inner_non_seqs, clean_inputs,
                            clean_outputs):
        """
        Return the index (0 or 1) of the matrix operand of ``dot_node`` if
        the Dot can be pushed out of the scan, ``None`` otherwise.

        The Dot qualifies when one operand has ``ndim == 2`` and is an inner
        non-sequence or a constant, while the other operand has ``ndim == 1``
        and is neither an input nor an output of the inner graph.
        """
        for idx_matrix, idx_vector in ((0, 1), (1, 0)):
            matrix = dot_node.inputs[idx_matrix]
            vector = dot_node.inputs[idx_vector]
            if (matrix.ndim == 2 and
                    (matrix in inner_non_seqs or
                     isinstance(matrix, tensor.Constant)) and
                    vector.ndim == 1 and
                    vector not in clean_inputs and
                    vector not in clean_outputs):
                return idx_matrix
        return None

    def add_nitsot_outputs(self, fgraph, scan_node, clean_inputs,
                           clean_outputs, new_output_inner):
        """
        Create a new scan that takes the same inputs as ``scan_node`` and
        produces the same outputs as well as ``new_output_inner``, exposed as
        an additional nit_sot output. The outputs of ``scan_node`` are
        replaced in ``fgraph`` by the matching outputs of the new scan.

        :param fgraph: the outer function graph containing ``scan_node``.
        :param scan_node: Apply node of the scan to extend.
        :param clean_inputs: inputs of the scan's inner graph.
        :param clean_outputs: outputs of the scan's inner graph.
        :param new_output_inner: inner-graph variable to add as an output.
        :return: a tuple ``(new_node, old_to_new_idx, new_output_idx)`` where
            ``new_node`` is the Apply node of the new scan,
            ``old_to_new_idx[i]`` is the index, among the outputs of the new
            scan, of output ``i`` of the old scan, and ``new_output_idx`` is
            the index of the added output.
        """
        info = scan_node.op.info

        # Compute the index at which to insert the new output. For a scan Op,
        # the outputs have the ordering : mit_mot, mit_sot, sit_sot, nit_sot
        # and shared_outs. The new output goes after the existing nit_sots.
        # NOTE(review): the same index is used below to splice the list of
        # *inner* outputs -- confirm it is valid when a mit_mot has more than
        # one output tap.
        output_insert_idx = (info['n_mit_mot'] +
                             info['n_mit_sot'] +
                             info['n_sit_sot'] +
                             info['n_nit_sot'])

        # Compile list of new inputs and outputs for the new Scan op
        _nw_op_ins = clean_inputs
        _nw_op_outs = (scan_utils.clone(clean_outputs[:output_insert_idx]) +
                       [new_output_inner] +
                       scan_utils.clone(clean_outputs[output_insert_idx:]))
        nw_op_ins, nw_op_outs = scan_utils.reconstruct_graph(_nw_op_ins,
                                                             _nw_op_outs)

        # Compile a list containing, for every output of the old scan op,
        # what its output index will be under the new scan op. Outputs
        # located at or after the insertion point are shifted by one.
        nw_op_output_indices = [i + int(i >= output_insert_idx)
                                for i in range(len(scan_node.outputs))]

        # Construct the new Scan op
        nw_info = info.copy()
        nw_info['n_nit_sot'] += 1
        nw_scan = scan_op.Scan(nw_op_ins, nw_op_outs, nw_info)

        # Assemble the list of inputs for the node that will apply the new
        # scan op by inserting the outer input of the new nit_sot right after
        # those of the existing nit_sots. The leading 1 accounts for n_steps,
        # which is the first input of every scan node.
        nw_node_input_idx = (1 +
                             info['n_seqs'] +
                             info['n_mit_mot'] +
                             info['n_mit_sot'] +
                             info['n_sit_sot'] +
                             info['n_shared_outs'] +
                             info['n_nit_sot'])

        # The outer input of a nit_sot is the number of steps to store. Every
        # step is needed by the outer Dot, so use the scan's own n_steps.
        nw_node_input_init_value = scan_node.inputs[0]
        nw_node_inputs = (scan_node.inputs[:nw_node_input_idx] +
                          [nw_node_input_init_value] +
                          scan_node.inputs[nw_node_input_idx:])

        # Build the Scan's apply node
        nw_node = nw_scan(*nw_node_inputs, return_list=True)[0].owner
        nw_node_old_outputs = (nw_node.outputs[:output_insert_idx] +
                               nw_node.outputs[output_insert_idx + 1:])

        # Make sure the outputs of the new scan op are used instead of the old
        fgraph.replace_all(
            zip(scan_node.outputs, nw_node_old_outputs),
            reason='scanOp_pushout_output')

        return nw_node, nw_op_output_indices, output_insert_idx
class ScanInplaceOptimizer(Optimizer):
"""Graph optimizer for Scan(makes it run inplace)"""
def __init__(self, typeConstructor=None, gpu_flag=False, gpua_flag=False):
......@@ -1800,6 +2015,14 @@ scan_seqopt1.register('scan_pushout_dot1',
'scan')
# Register PushOutScanOutput in the scan_seqopt1 optimizer database under the
# name 'scanOp_pushout_output', tagged 'fast_run', 'more_mem' and 'scan'.
# NOTE(review): the 5 is presumably the position of this optimization in the
# sequence -- confirm against the neighbouring registrations.
scan_seqopt1.register(
    'scanOp_pushout_output', PushOutScanOutput(), 5,
    'fast_run', 'more_mem', 'scan')
scan_eqopt2.register('constant_folding_for_scan2',
opt.in2out(tensor.opt.constant_folding,
ignore_newtrees=True),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论