提交 58dc930e authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Refactor optimization PushOutScanOutput

上级 a3870bb2
......@@ -605,11 +605,6 @@ class PushOutScanOutput(gof.Optimizer):
def process_node(self, fgraph, node):
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs, clone=False)
op = node.op
# Use scan_args to parse the inputs and outputs of scan for ease of
......@@ -617,29 +612,21 @@ class PushOutScanOutput(gof.Optimizer):
args = scan_args(node.inputs, node.outputs,
node.op.inputs, node.op.outputs, node.op.info)
# Obtain the list containing the indices, in clean_outputs, of the
# scan op's outputs that are nit_sot (not fed back to the inner fct.)
nitsot_outs = op.inner_nitsot_outs(node.outputs)
idx_nitsot_outs = [node.outputs.index(i) for i in nitsot_outs]
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
inner_seqs = op.inner_seqs(clean_inputs)
outer_seqs = op.outer_seqs(node.inputs)
local_fgraph = gof.FunctionGraph(args.inner_inputs,
args.inner_outputs,
clone=False)
new_scan_node = None
for nd in local_fgraph.toposort():
if (isinstance(nd.op, theano.tensor.Dot) and
nd.out in clean_outputs):
nd.out in args.inner_out_nit_sot):
"""
The following optimization involves pushing out, after the
scan, a Dot where one input is one of scan's input with ndim=2
scan, a Dot whose output is nitsot (not fed back to the inner
graph) and where one input is one of scan's input with ndim=2
and the other is an intermediate variable in the Scan inner
graph with ndim=1.
......@@ -648,29 +635,11 @@ class PushOutScanOutput(gof.Optimizer):
concatenating the vectors into a matrix.
"""
# Go through clean_outputs and pick one that is
# - Equal to the output of the tensor.Dot
# - Nit_sot : not fed back to the inner graph because applying
# the optimization in that case would alter the results of
# the function
# - Used by something outside of the graph to avoid applying
# the optimization needlessly
idx_dot_output = -1
for i in range(len(clean_outputs)):
is_dot_output = (nd.out == clean_outputs[i])
is_nitsot_output = i in idx_nitsot_outs
used_in_outer_graph = (len(node.outputs[i].clients) > 0)
if (is_dot_output and is_nitsot_output and
used_in_outer_graph):
idx_dot_output = i
break
if idx_dot_output == -1:
# The dot has no output that fits the requirements for
# this optimization. Move on to the next node.
# Ensure that the output of the Dot is used in the outer
# graph to avoid applying the optimization needlessly
dot_out_nitsot_idx = args.inner_out_nit_sot.index(nd.out)
outer_dot_output = args.outer_out_nit_sot[dot_out_nitsot_idx]
if len(outer_dot_output.clients) == 0:
continue
"""
......@@ -684,75 +653,39 @@ class PushOutScanOutput(gof.Optimizer):
idx_vector_input = -1
if (nd.inputs[0].ndim == 2 and
(nd.inputs[0] in inner_non_seqs or
(nd.inputs[0] in args.inner_in_non_seqs or
isinstance(nd.inputs[0], tensor.Constant)) and
nd.inputs[1].ndim == 1 and
(nd.inputs[1] in inner_seqs or
nd.inputs[1] not in clean_inputs)):
(nd.inputs[1] in args.inner_in_seqs or
nd.inputs[1] not in args.inner_inputs)):
valid_inputs = True
idx_matrix_input = 0
idx_vector_input = 1
elif (nd.inputs[1].ndim == 2 and
(nd.inputs[1] in inner_non_seqs or
(nd.inputs[1] in args.inner_in_non_seqs or
isinstance(nd.inputs[1], tensor.Constant)) and
nd.inputs[0].ndim == 1 and
(nd.inputs[0] in inner_seqs or
nd.inputs[0] not in clean_inputs)):
(nd.inputs[0] in args.inner_in_seqs or
nd.inputs[0] not in args.inner_inputs)):
valid_inputs = True
idx_matrix_input = 1
idx_vector_input = 0
if valid_inputs:
# The optimization can be applied on the current Dot
# Create a copy of the Dot's matrix input outside
# of scan
inner_matrix_input = nd.inputs[idx_matrix_input]
if inner_matrix_input in inner_non_seqs:
_idx = inner_non_seqs.index(inner_matrix_input)
outer_matrix_input = outer_non_seqs[_idx]
elif isinstance(inner_matrix_input, theano.Constant):
outer_matrix_input = inner_matrix_input.clone()
else:
# Should not have happened
raise Exception(
('Error in the `scan_pushout_seq_'
'operations`. The optimization tries '
'to move some computation fron scan '
'which is not allowed to move. Report '
'this on theano-users list'),
inner_matrix_input)
# If the vector_input is already a nit_sot output of the
# scan, get a reference to the corresponding outer output.
# Otherwise, add it as a new nit_sot output and then get a
# reference to it
if nd.inputs[idx_vector_input] in inner_seqs:
_idx = inner_seqs.index(nd.inputs[idx_vector_input])
outer_vector_input = outer_seqs[_idx]
elif nd.inputs[idx_vector_input] in nitsot_outs:
# Figure out which scan output corresponds the vector
# input
inner_vector_input = nd.inputs[idx_vector_input]
vector_input_nitsot_idx = args.inner_out_nit_sot.index(inner_vector_input)
outer_vector_input = args.outer_out_nit_sot[vector_input_nitsot_idx]
else:
# Add the vector_input as a new nitsot output to scan
new_output_inner = nd.inputs[idx_vector_input]
new_scan_node, idx_old_outputs, idx_new_output = self.add_nitsot_outputs(
fgraph, node,
clean_inputs,
clean_outputs,
new_output_inner)
outer_vector_input = new_scan_node.outputs[idx_new_output]
node = new_scan_node
idx_dot_output = idx_old_outputs[idx_dot_output]
# Move out of scan the two inputs to the Dot
(outer_vars,
new_scan_node,
new_scan_args) = self.push_out_inner_vars(fgraph,
nd.inputs,
node, args)
outer_vector_input = outer_vars[idx_vector_input]
outer_matrix_input = outer_vars[idx_matrix_input]
# Perform the Dot outside of scan
if idx_matrix_input == 0:
......@@ -766,7 +699,7 @@ class PushOutScanOutput(gof.Optimizer):
# Modify the outer graph to add the outer Dot
fgraph.replace_all([
(node.outputs[idx_dot_output],
(new_scan_args.outer_out_nit_sot[dot_out_nitsot_idx],
outer_dot_output)],
reason="scanOp_pushout_output")
......@@ -774,71 +707,135 @@ class PushOutScanOutput(gof.Optimizer):
return new_scan_node
def add_nitsot_outputs(self, fgraph, scan_node, clean_inputs,
clean_outputs, new_output_inner):
def inner_sitsot_only_last_step_used(self, var, scan_args):
"""
Create a new scan that takes the same inputs as scan_node and produces
the same output as well as the provided output new_output_inner
Given an inner sit_sot output of scan, return True iff the outer
sit_sot output has only one client and that client is a Subtensor
instance that takes only the last step (last element along the first
axis).
"""
idx = scan_args.inner_out_sit_sot.index(var)
outer_var = scan_args.outer_out_sit_sot[idx]
# Compute the index at which to insert the new output. For a scan Op,
# the outputs follow the ordering : mit_mot, mit_sot, sit_sot, nit_sot
# and shared_outs
output_insert_idx = (scan_node.op.info['n_mit_mot'] +
scan_node.op.info['n_mit_sot'] +
scan_node.op.info['n_sit_sot'] +
scan_node.op.info['n_nit_sot'])
# Compile list of new inputs and outputs for the new Scan op
_nw_op_ins = clean_inputs
_nw_op_outs = (scan_utils.clone(clean_outputs[:output_insert_idx]) +
[new_output_inner] +
scan_utils.clone(clean_outputs[output_insert_idx:]))
nw_op_ins, nw_op_outs = scan_utils.reconstruct_graph(_nw_op_ins,
_nw_op_outs)
# Compile a list containing, for every output of the old scan op,
# what its output index will be under the new scan op
nw_op_output_indices = [i + int(i>output_insert_idx)
for i in range(output_insert_idx)]
# Construct the new Scan op
nw_info = scan_node.op.info.copy()
nw_info['n_nit_sot'] += 1
nw_scan = scan_op.Scan(nw_op_ins, nw_op_outs, nw_info)
# Assemble the list of inputs for the node that will apply the new
# scan op by inserting an initial value for the new output at the
# right position in the list of inputs for the old node.
nw_node_input_idx = (scan_node.op.info['n_seqs'] +
scan_node.op.info['n_mit_mot'] +
scan_node.op.info['n_mit_sot'] +
scan_node.op.info['n_sit_sot'] +
scan_node.op.info['n_shared_outs'] +
scan_node.op.info['n_nit_sot'])
if len(outer_var.clients) == 1:
# (the initial value is the nb of steps to store. For a nitsot,
# it should be the number of steps performed by scan)
nw_node_input_init_value = scan_node.inputs[0]
client = outer_var.clients[0][0]
if (isinstance(client.op, theano.tensor.Subtensor) and
isinstance(client.inputs[1], theano.Constant) and
client.inputs[1].ndim == 0 and
client.inputs[1].value == -1):
nw_node_inputs = (scan_node.inputs[:nw_node_input_idx] +
[nw_node_input_init_value] +
scan_node.inputs[nw_node_input_idx:])
return True
return False
# Build the Scan's apply node
nw_node = nw_scan(*nw_node_inputs, **dict(return_list=True))[0].owner
def get_outer_ndim(self, var, scan_args):
nw_node_old_outputs = (nw_node.outputs[:output_insert_idx] +
nw_node.outputs[output_insert_idx+1:])
# Given a variable, determine the number of dimension it would have if
# it was pushed out of scan
if (var in scan_args.inner_in_non_seqs or
isinstance(var, theano.Constant)):
# Make sure the outputs of the new scan op are used instead of the old
fgraph.replace_all(
zip(scan_node.outputs, nw_node_old_outputs),
reason='scanOp_pushout_output')
outer_ndim = var.ndim
else:
outer_ndim = var.ndim + 1
return outer_ndim
def push_out_inner_vars(self, fgraph, inner_vars, old_scan_node,
old_scan_args):
outer_vars = [None] * len(inner_vars)
new_scan_node = old_scan_node
new_scan_args = old_scan_args
# For the inner_vars that already exist in the outer graph,
# simply obtain a reference to them
for idx in range(len(inner_vars)):
var = inner_vars[idx]
if var in old_scan_args.inner_in_seqs:
idx_seq = old_scan_args.inner_in_seqs.index(var)
outer_vars[idx] = old_scan_args.outer_in_seqs[idx_seq]
elif var in old_scan_args.inner_in_non_seqs:
idx_non_seq = old_scan_args.inner_in_non_seqs.index(var)
outer_vars[idx] = old_scan_args.outer_in_non_seqs[idx_non_seq]
elif isinstance(var, theano.Constant):
outer_vars[idx] = var.clone()
return nw_node, nw_op_output_indices, output_insert_idx
elif var in old_scan_args.inner_out_nit_sot:
idx_nitsot = old_scan_args.inner_out_nit_sot.index(var)
outer_vars[idx] = old_scan_args.outer_out_nit_sot[idx_nitsot]
# For the inner_vars that don't already exist in the outer graph, add
# them as new nitsot outputs to the scan node.
idx_add_as_nitsots = [i for i in range(len(outer_vars))
if outer_vars[i] == None]
add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots]
if len(add_as_nitsots) > 0:
new_scan_node = self.add_nitsot_outputs(fgraph,old_scan_node,
old_scan_args,
add_as_nitsots)
new_scan_args = scan_args(new_scan_node.inputs,
new_scan_node.outputs,
new_scan_node.op.inputs,
new_scan_node.op.outputs,
new_scan_node.op.info)
new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots):]
for i in range(len(new_outs)):
outer_vars[idx_add_as_nitsots[i]] = new_outs[i]
return outer_vars, new_scan_node, new_scan_args
def add_nitsot_outputs(self, fgraph, old_scan_node,
                       old_scan_args, new_outputs_inner):
    """
    Build a scan node that produces the outputs of `old_scan_node` plus
    the inner variables in `new_outputs_inner` as extra nit_sot outputs,
    substitute it for the old node in `fgraph`, and return the new node.

    `old_scan_args` is the `scan_args` description of `old_scan_node`; a
    copy of it is extended to describe the new scan.
    """
    extra_count = len(new_outputs_inner)

    # Every added nitsot needs an initial value, which is the number of
    # steps to store. For a nitsot this should be the number of steps
    # performed by scan; it is taken from the first input of the old node.
    n_steps = old_scan_node.inputs[0]
    extra_initial_values = [n_steps] * extra_count

    # Describe the new scan: the old description plus the extra nitsots.
    # NOTE(review): this relies on copy.copy() giving the copy its own
    # nit_sot lists -- confirm against the scan_args implementation.
    extended_args = copy.copy(old_scan_args)
    extended_args.inner_out_nit_sot.extend(new_outputs_inner)
    extended_args.outer_in_nit_sot.extend(extra_initial_values)

    # Instantiate the op from the extended description, then apply it to
    # obtain the new Apply node.
    extended_op = scan_op.Scan(extended_args.inner_inputs,
                               extended_args.inner_outputs,
                               extended_args.info)
    applied_outputs = extended_op(*extended_args.outer_inputs,
                                  return_list=True)
    new_scan_node = applied_outputs[0].owner

    # Pick, among the new node's outputs, the ones matching the old
    # node's outputs: skip the `extra_count` outputs starting at the
    # position obtained by discounting the shared outputs from the old
    # output count.
    first_added = (len(old_scan_args.outer_outputs) -
                   len(old_scan_args.outer_out_shared))
    kept_outputs = (new_scan_node.outputs[:first_added] +
                    new_scan_node.outputs[first_added + extra_count:])

    # Make the outer graph use the new scan's outputs instead of the old
    # ones; the old scan node is passed as `remove`.
    fgraph.replace_all_validate_remove(
        zip(old_scan_node.outputs, kept_outputs),
        remove=[old_scan_node],
        reason='scanOp_pushout_output')

    return new_scan_node
class ScanInplaceOptimizer(Optimizer):
"""Graph optimizer for Scan(makes it run inplace)"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论