提交 ea54556b authored 作者: abergeron's avatar abergeron

Merge pull request #2248 from carriepl/scan_push_out_dot

Scan optimization to push Dot products out of Scan's inner graph
......@@ -583,6 +583,263 @@ class PushOutSeqScan(gof.Optimizer):
return False
class PushOutScanOutput(gof.Optimizer):
    """
    Push operations performed at the end of the inner graph of scan to
    outside of scan.

    Currently handles one pattern: an inner-graph ``Dot`` between a
    matrix that is loop-invariant (a non-sequence input or a constant)
    and a vector computed inside the loop.  The per-step vectors are
    collected into a matrix (via a nit_sot output) and a single outer
    ``Dot`` replaces the per-step inner ones.
    """

    def __init__(self):
        gof.Optimizer.__init__(self)

    def add_requirements(self, fgraph):
        # Replacements performed below must be validated by the fgraph.
        fgraph.attach_feature(gof.toolbox.ReplaceValidate())

    def apply(self, fgraph):
        nodelist = [x for x in fgraph.toposort()
                    if isinstance(x.op, scan_op.Scan)]
        for node in nodelist:
            # Process the node as long as something gets optimized.
            # process_node returns the replacement scan node when it had
            # to build a new scan, and None once nothing more applies.
            # (PEP 8: compare against None with `is not`, not `!=`.)
            while node is not None:
                node = self.process_node(fgraph, node)

    def process_node(self, fgraph, node):
        """
        Try to push one inner Dot of the scan `node` out of the loop.

        Returns the new scan Apply node when a new scan op had to be
        created (so the caller can keep iterating on it), else None.
        """
        # Rebuild a standalone copy of scan's inner graph so it can be
        # inspected through a FunctionGraph without mutating the op.
        clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
            node.op.inputs, node.op.outputs)
        local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs,
                                         clone=False)

        op = node.op

        # Use scan_args to parse the inputs and outputs of scan for ease
        # of use
        args = scan_args(node.inputs, node.outputs,
                         node.op.inputs, node.op.outputs, node.op.info)

        # Obtain the list containing the indices, in clean_outputs, of the
        # scan op's outputs that are nit_sot (not fed back to the inner
        # fct.)
        # NOTE(review): inner_nitsot_outs is applied to the *outer*
        # outputs here; it slices by count, so this should be confirmed
        # against scan_op's inner/outer output ordering conventions.
        nitsot_outs = op.inner_nitsot_outs(node.outputs)
        idx_nitsot_outs = [node.outputs.index(i) for i in nitsot_outs]

        # Construct the list of non_sequences to simplify a few things
        inner_non_seqs = op.inner_non_seqs(clean_inputs)
        outer_non_seqs = op.outer_non_seqs(node.inputs)
        assert len(inner_non_seqs) == len(outer_non_seqs)

        inner_seqs = op.inner_seqs(clean_inputs)
        outer_seqs = op.outer_seqs(node.inputs)

        new_scan_node = None

        for nd in local_fgraph.toposort():
            if (isinstance(nd.op, theano.tensor.Dot) and
                    nd.out in clean_outputs):
                # The following optimization involves pushing out, after
                # the scan, a Dot where one input is one of scan's inputs
                # with ndim=2 and the other is an intermediate variable in
                # the Scan inner graph with ndim=1.  The Dot product is
                # pushed out of the scan and its inputs are now the
                # original matrix and a new matrix obtained by
                # concatenating the per-step vectors into a matrix.

                # Go through clean_outputs and pick one that is
                # - Equal to the output of the tensor.Dot
                # - Nit_sot : not fed back to the inner graph because
                #   applying the optimization in that case would alter the
                #   results of the function
                # - Used by something outside of the graph, to avoid
                #   applying the optimization needlessly
                idx_dot_output = -1
                for i in range(len(clean_outputs)):
                    is_dot_output = (nd.out == clean_outputs[i])
                    is_nitsot_output = i in idx_nitsot_outs
                    used_in_outer_graph = (len(node.outputs[i].clients) > 0)

                    if (is_dot_output and is_nitsot_output and
                            used_in_outer_graph):
                        idx_dot_output = i
                        break

                if idx_dot_output == -1:
                    # The dot has no output that fits the requirements for
                    # this optimization. Move on to the next node.
                    continue

                # Validate that one of the inputs is a matrix AND a
                # non-sequence input to scan (or a constant) and that the
                # other input is a vector and either a sequence input to
                # scan or the result of computation in the inner function
                # of scan.
                valid_inputs = False
                idx_matrix_input = -1
                idx_vector_input = -1

                if (nd.inputs[0].ndim == 2 and
                        (nd.inputs[0] in inner_non_seqs or
                         isinstance(nd.inputs[0], tensor.Constant)) and
                        nd.inputs[1].ndim == 1 and
                        (nd.inputs[1] in inner_seqs or
                         nd.inputs[1] not in clean_inputs)):
                    valid_inputs = True
                    idx_matrix_input = 0
                    idx_vector_input = 1
                elif (nd.inputs[1].ndim == 2 and
                        (nd.inputs[1] in inner_non_seqs or
                         isinstance(nd.inputs[1], tensor.Constant)) and
                        nd.inputs[0].ndim == 1 and
                        (nd.inputs[0] in inner_seqs or
                         nd.inputs[0] not in clean_inputs)):
                    valid_inputs = True
                    idx_matrix_input = 1
                    idx_vector_input = 0

                if valid_inputs:
                    # The optimization can be applied on the current Dot.

                    # Create a copy of the Dot's matrix input outside
                    # of scan
                    inner_matrix_input = nd.inputs[idx_matrix_input]
                    if inner_matrix_input in inner_non_seqs:
                        _idx = inner_non_seqs.index(inner_matrix_input)
                        outer_matrix_input = outer_non_seqs[_idx]
                    elif isinstance(inner_matrix_input, theano.Constant):
                        outer_matrix_input = inner_matrix_input.clone()
                    else:
                        # Should not have happened: the validation above
                        # only accepts non-sequences and constants.
                        raise Exception(
                            ('Error in the `scan_pushout_seq_'
                             'operations`. The optimization tries '
                             'to move some computation from scan '
                             'which is not allowed to move. Report '
                             'this on theano-users list'),
                            inner_matrix_input)

                    # If the vector_input is already a nit_sot output of
                    # the scan, get a reference to the corresponding outer
                    # output. Otherwise, add it as a new nit_sot output
                    # and then get a reference to it
                    if nd.inputs[idx_vector_input] in inner_seqs:
                        _idx = inner_seqs.index(nd.inputs[idx_vector_input])
                        outer_vector_input = outer_seqs[_idx]
                    elif nd.inputs[idx_vector_input] in nitsot_outs:
                        # Figure out which scan output corresponds to the
                        # vector input
                        inner_vector_input = nd.inputs[idx_vector_input]
                        vector_input_nitsot_idx = \
                            args.inner_out_nit_sot.index(inner_vector_input)
                        outer_vector_input = \
                            args.outer_out_nit_sot[vector_input_nitsot_idx]
                    else:
                        # Add the vector_input as a new nitsot output to
                        # scan; this rebuilds the scan node, so update the
                        # local references to node and to the dot output's
                        # index accordingly.
                        new_output_inner = nd.inputs[idx_vector_input]
                        (new_scan_node,
                         idx_old_outputs,
                         idx_new_output) = self.add_nitsot_outputs(
                            fgraph, node,
                            clean_inputs,
                            clean_outputs,
                            new_output_inner)

                        outer_vector_input = \
                            new_scan_node.outputs[idx_new_output]
                        node = new_scan_node
                        idx_dot_output = idx_old_outputs[idx_dot_output]

                    # Perform the Dot outside of scan
                    if idx_matrix_input == 0:
                        # Inner graph computed dot(matrix, vector); the
                        # stacked per-step vectors are the rows of
                        # outer_vector_input, hence dot(rows, matrix.T).
                        outer_dot_inputs = [outer_vector_input,
                                            outer_matrix_input.transpose()]
                        outer_dot_output = theano.tensor.dot(
                            *outer_dot_inputs)
                    else:  # idx_matrix_input == 1
                        # Inner graph computed dot(vector, matrix).
                        outer_dot_inputs = [outer_vector_input,
                                            outer_matrix_input]
                        outer_dot_output = theano.tensor.dot(
                            *outer_dot_inputs)

                    # Modify the outer graph to add the outer Dot
                    fgraph.replace_all([
                        (node.outputs[idx_dot_output],
                         outer_dot_output)],
                        reason="scanOp_pushout_output")

                    break

        return new_scan_node

    def add_nitsot_outputs(self, fgraph, scan_node, clean_inputs,
                           clean_outputs, new_output_inner):
        """
        Create a new scan that takes the same inputs as scan_node and
        produces the same outputs, plus the provided output
        new_output_inner appended as an extra nit_sot output.

        Returns (new_apply_node, old_to_new_output_index_list,
        index_of_the_new_output).
        """
        # Compute the index at which to insert the new output. For a scan
        # Op, the outputs follow the ordering : mit_mot, mit_sot, sit_sot,
        # nit_sot and shared_outs
        output_insert_idx = (scan_node.op.info['n_mit_mot'] +
                             scan_node.op.info['n_mit_sot'] +
                             scan_node.op.info['n_sit_sot'] +
                             scan_node.op.info['n_nit_sot'])

        # Compile list of new inputs and outputs for the new Scan op
        _nw_op_ins = clean_inputs
        _nw_op_outs = (scan_utils.clone(clean_outputs[:output_insert_idx]) +
                       [new_output_inner] +
                       scan_utils.clone(clean_outputs[output_insert_idx:]))
        nw_op_ins, nw_op_outs = scan_utils.reconstruct_graph(_nw_op_ins,
                                                             _nw_op_outs)

        # Compile a list containing, for every output of the old scan op,
        # what its output index will be under the new scan op.
        # NOTE(review): this only covers indices below the insertion point
        # (for which the index is unchanged, since i > output_insert_idx
        # is always False in this range); shared outputs coming after the
        # insertion point are not remapped. Callers only look up nit_sot
        # indices, which all precede output_insert_idx.
        nw_op_output_indices = [i + int(i > output_insert_idx)
                                for i in range(output_insert_idx)]

        # Construct the new Scan op
        nw_info = scan_node.op.info.copy()
        nw_info['n_nit_sot'] += 1
        nw_scan = scan_op.Scan(nw_op_ins, nw_op_outs, nw_info)

        # Assemble the lists of inputs for the node that will apply the
        # new scan op by inserting an initial value for the new input at
        # the right position in the list of inputs of the old node.
        nw_node_input_idx = (scan_node.op.info['n_seqs'] +
                             scan_node.op.info['n_mit_mot'] +
                             scan_node.op.info['n_mit_sot'] +
                             scan_node.op.info['n_sit_sot'] +
                             scan_node.op.info['n_shared_outs'] +
                             scan_node.op.info['n_nit_sot'])

        # (the initial value is the nb of steps to store. For a nitsot,
        # it should be the number of steps performed by scan)
        nw_node_input_init_value = scan_node.inputs[0]

        nw_node_inputs = (scan_node.inputs[:nw_node_input_idx] +
                          [nw_node_input_init_value] +
                          scan_node.inputs[nw_node_input_idx:])

        # Build the Scan's apply node
        nw_node = nw_scan(*nw_node_inputs, **dict(return_list=True))[0].owner

        # Old outputs of the new node, i.e. everything except the freshly
        # inserted nit_sot output.
        nw_node_old_outputs = (nw_node.outputs[:output_insert_idx] +
                               nw_node.outputs[output_insert_idx + 1:])

        # Make sure the outputs of the new scan op are used instead of
        # the outputs of the old one.
        fgraph.replace_all(
            zip(scan_node.outputs, nw_node_old_outputs),
            reason='scanOp_pushout_output')

        return nw_node, nw_op_output_indices, output_insert_idx
class ScanInplaceOptimizer(Optimizer):
"""Graph optimizer for Scan(makes it run inplace)"""
def __init__(self, typeConstructor=None, gpu_flag=False, gpua_flag=False):
......@@ -1797,6 +2054,14 @@ scan_seqopt1.register('scan_pushout_dot1',
'scan')
# Register PushOutScanOutput at position 5 of the scan_seqopt1 sequence,
# enabled under 'fast_run'.  The 'more_mem' tag presumably marks it as
# trading memory for speed (the pushed-out Dot requires storing every
# per-step vector) — confirm against the optimizer-tag conventions.
scan_seqopt1.register('scanOp_pushout_output',
                      PushOutScanOutput(),
                      5,
                      'fast_run',
                      'more_mem',
                      'scan')
scan_eqopt2.register('constant_folding_for_scan2',
opt.in2out(tensor.opt.constant_folding,
ignore_newtrees=True),
......
......@@ -2548,7 +2548,7 @@ class T_Scan(unittest.TestCase):
x = theano.tensor.fmatrix('x')
mem_val = numpy.zeros((2,), dtype='float32')
memory = theano.shared(mem_val.copy())
memory = theano.shared(mem_val)
W = theano.shared(numpy.random.random((5, 2)).astype('float32'))
def f(inp, mem):
......@@ -2557,8 +2557,8 @@ class T_Scan(unittest.TestCase):
return d, d
outs, updts = theano.scan(f, sequences=[x],
non_sequences=[],
outputs_info=[None, memory])
non_sequences=[],
outputs_info=[None, memory])
f = theano.function([x], outs[0])
f2 = theano.function([x], outs[1])
......@@ -2566,7 +2566,7 @@ class T_Scan(unittest.TestCase):
x_val = numpy.random.random((4, 3)).astype('float32')
f_vals = f(x_val)
memory.set_value(mem_val.copy())
memory.set_value(mem_val)
f2_vals = f2(x_val)
utt.assert_allclose(f_vals, f2_vals)
......
......@@ -4,6 +4,7 @@ import unittest
import theano
from theano import config
from theano import tensor as T
from theano.scan_module.scan_op import Scan
from theano.tests import unittest_tools as utt
mode = theano.compile.mode.get_mode(config.mode)
......@@ -134,3 +135,140 @@ class GaussNewtonMatrix(object):
# apply Tikhonov damping
JHJv = [JHJvi + damp * vi for JHJvi, vi in zip(JHJv, v)]
return JHJv
class TestPushOutScanOutputDot(object):
    """
    Test class for the PushOutScanOutput optimizer in the case where the
    inner function of a scan op has an output which is the result of a
    Dot product on a non-sequence matrix input to scan and a vector that
    is the result of computation in the inner function.
    """

    @staticmethod
    def _assert_no_inner_dot(scan_node):
        # Helper: no output of the scan's inner graph may be produced by
        # a Dot op.  (The previous check `isinstance(output, T.Dot)`
        # compared a Variable against an Op class, which is always False,
        # so the assertion was vacuous; the op of the variable's owner
        # must be inspected instead.)
        assert not any(o.owner is not None and
                       isinstance(o.owner.op, T.Dot)
                       for o in scan_node.op.outputs)

    def test_dot_not_output(self):
        """
        Test the case where the vector input to the dot is not already an
        output of the inner function.
        """
        v = T.vector()
        m = T.matrix()
        output = T.dot(v, m)

        # Compile the function twice, once with the optimization and once
        # without.  BUGFIX: Mode.excluding() returns a *new* Mode; the
        # previous code discarded its result, so f_no_opt was compiled
        # WITH the optimization and the test compared it to itself.
        f_opt = theano.function([v, m], T.jacobian(output, v))

        default_mode = theano.compile.get_default_mode()
        f_no_opt = theano.function(
            [v, m], T.jacobian(output, v),
            mode=default_mode.excluding("scanOp_pushout_output"))

        # Ensure that the optimization was performed correctly in f_opt
        # The inner function of scan should have only one output and it
        # should not be the result of a Dot
        scan_node = [node for node in f_opt.maker.fgraph.toposort()
                     if isinstance(node.op, Scan)][0]
        assert len(scan_node.op.outputs) == 1
        self._assert_no_inner_dot(scan_node)

        # Ensure that the function compiled with the optimization produces
        # the same results as the function compiled without
        v_value = numpy.random.random((4))
        m_value = numpy.random.random((4, 5))

        output_opt = f_opt(v_value, m_value)
        output_no_opt = f_no_opt(v_value, m_value)

        utt.assert_allclose(output_opt, output_no_opt)

    def test_dot_nitsot_output(self):
        """
        Test the case where the vector input to the dot is already a
        nitsot output of the inner function.
        """
        a = T.matrix()
        b = T.matrix()

        def inner_fct(vect, mat):
            vect_squared = vect ** 2
            return T.dot(vect_squared, mat), vect_squared

        outputs, updates = theano.scan(fn=inner_fct,
                                       outputs_info=[None] * 2,
                                       sequences=a,
                                       non_sequences=b)

        # Compile the function twice, once with the optimization and once
        # without (see test_dot_not_output for the Mode.excluding fix).
        f_opt = theano.function([a, b], outputs)

        default_mode = theano.compile.get_default_mode()
        f_no_opt = theano.function(
            [a, b], outputs,
            mode=default_mode.excluding("scanOp_pushout_output"))

        # Ensure that the optimization was performed correctly in f_opt
        # The inner function of scan should have only one output and it
        # should not be the result of a Dot
        scan_node = [node for node in f_opt.maker.fgraph.toposort()
                     if isinstance(node.op, Scan)][0]
        assert len(scan_node.op.outputs) == 1
        self._assert_no_inner_dot(scan_node)

        # Ensure that the function compiled with the optimization produces
        # the same results as the function compiled without
        a_value = numpy.random.random((3, 4))
        b_value = numpy.random.random((4, 5))

        output_opt = f_opt(a_value, b_value)
        output_no_opt = f_no_opt(a_value, b_value)

        utt.assert_allclose(output_opt[0], output_no_opt[0])
        utt.assert_allclose(output_opt[1], output_no_opt[1])

    def test_dot_sitsot_output(self):
        """
        Test the case where the vector input to the dot is a non-nitsot
        (in this case a sitsot) output of the inner function.
        """
        a = T.matrix()
        b = T.matrix()

        def inner_fct(seq1, previous_output1, nonseq1):
            output1 = previous_output1 + seq1
            output2 = T.dot(output1, nonseq1)
            return output1, output2

        outputs, updates = theano.scan(fn=inner_fct,
                                       outputs_info=[a[0], None],
                                       sequences=a,
                                       non_sequences=b)

        # Compile the function twice, once with the optimization and once
        # without (see test_dot_not_output for the Mode.excluding fix).
        f_opt = theano.function([a, b], outputs)

        default_mode = theano.compile.get_default_mode()
        f_no_opt = theano.function(
            [a, b], outputs,
            mode=default_mode.excluding("scanOp_pushout_output"))

        # Ensure that the optimization was performed correctly in f_opt
        # The inner function of scan should have two outputs (the sitsot
        # state and the new nitsot collecting the vectors) and none of
        # them should be the result of a Dot
        scan_node = [node for node in f_opt.maker.fgraph.toposort()
                     if isinstance(node.op, Scan)][0]
        assert len(scan_node.op.outputs) == 2
        self._assert_no_inner_dot(scan_node)

        # Ensure that the function compiled with the optimization produces
        # the same results as the function compiled without
        a_value = numpy.random.random((3, 4))
        b_value = numpy.random.random((4, 5))

        output_opt = f_opt(a_value, b_value)
        output_no_opt = f_no_opt(a_value, b_value)

        utt.assert_allclose(output_opt[0], output_no_opt[0])
        utt.assert_allclose(output_opt[1], output_no_opt[1])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论