提交 7d2dc922 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Add tests and modify the optimization to cover new cases

上级 4258ebb1
......@@ -627,6 +627,9 @@ class PushOutScanOutput(gof.Optimizer):
outer_non_seqs = op.outer_non_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
inner_seqs = op.inner_seqs(clean_inputs)
outer_seqs = op.outer_seqs(node.inputs)
new_scan_node = None
for nd in local_fgraph.toposort():
......@@ -693,7 +696,8 @@ class PushOutScanOutput(gof.Optimizer):
(nd.inputs[1] in inner_non_seqs or
isinstance(nd.inputs[1], tensor.Constant)) and
nd.inputs[0].ndim == 1 and
nd.inputs[0] not in clean_inputs):
(nd.inputs[0] in inner_seqs or
nd.inputs[0] not in clean_inputs)):
valid_inputs = True
idx_matrix_input = 1
......@@ -724,7 +728,11 @@ class PushOutScanOutput(gof.Optimizer):
# scan, get a reference to the corresponding outer output.
# Otherwise, add it as a new nit_sot output and then get a
# reference to it
if nd.inputs[idx_vector_input] in nitsot_outs:
if nd.inputs[idx_vector_input] in inner_seqs:
_idx = inner_seqs.index(nd.inputs[idx_vector_input])
outer_vector_input = outer_seqs[_idx]
elif nd.inputs[idx_vector_input] in nitsot_outs:
# Figure out which scan output corresponds the vector
# input
inner_vector_input = nd.inputs[idx_vector_input]
......@@ -741,6 +749,9 @@ class PushOutScanOutput(gof.Optimizer):
new_output_inner)
outer_vector_input = new_scan_node.outputs[idx_new_output]
node = new_scan_node
idx_dot_output = idx_old_outputs[idx_dot_output]
# Perform the Dot outside of scan
if idx_matrix_input == 0:
outer_dot_inputs = [outer_matrix_input,
......@@ -752,9 +763,8 @@ class PushOutScanOutput(gof.Optimizer):
outer_dot_output = theano.tensor.dot(*outer_dot_inputs)
# Modify the outer graph to add the outer Dot
new_idx_dot_output = idx_old_outputs[idx_dot_output]
fgraph.replace_all([
(new_scan_node.outputs[new_idx_dot_output],
(node.outputs[idx_dot_output],
outer_dot_output)],
reason="scanOp_pushout_output")
......
......@@ -4,6 +4,7 @@ import unittest
import theano
from theano import config
from theano import tensor as T
from theano.scan_module.scan_op import Scan
from theano.tests import unittest_tools as utt
mode = theano.compile.mode.get_mode(config.mode)
......@@ -134,3 +135,140 @@ class GaussNewtonMatrix(object):
# apply Tikhonov damping
JHJv = [JHJvi + damp * vi for JHJvi, vi in zip(JHJv, v)]
return JHJv
class TestPushOutScanOutputDot(object):
"""
Test class for the PushOutScanOutput optimizer in the case where the inner
function of a scan op has an output which is the result of a Dot product
on a non-sequence matrix input to scan and a vector that is the result of
computation in the inner function.
"""
def test_dot_not_output(self):
"""
Test the case where the vector input to the dot is not already an
output of the inner function.
"""
v = T.vector()
m = T.matrix()
output = T.dot(v, m)
# Compile the function twice, once with the optimization and once
# without
f_opt = theano.function([v, m], T.jacobian(output, v))
default_mode = theano.compile.get_default_mode()
default_mode.excluding("scanOp_pushout_output")
f_no_opt = theano.function([v, m], T.jacobian(output, v),
mode=default_mode)
# Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should
# not be the result of a Dot
scan_node = [node for node in f_opt.maker.fgraph.toposort()
if isinstance(node.op, Scan)][0]
assert len(scan_node.op.outputs) == 1
assert not isinstance(scan_node.op.outputs[0], T.Dot)
# Ensure that the function compiled with the optimization produces
# the same results as the function compiled without
v_value = numpy.random.random((4))
m_value = numpy.random.random((4, 5))
output_opt = f_opt(v_value, m_value)
output_no_opt = f_no_opt(v_value, m_value)
utt.assert_allclose(output_opt, output_no_opt)
def test_dot_nitsot_output(self):
"""
Test the case where the vector input to the dot is already a nitsot
output of the inner function.
"""
a = T.matrix()
b = T.matrix()
def inner_fct(vect, mat):
vect_squared = vect ** 2
return T.dot(vect_squared, mat), vect_squared
outputs, updates = theano.scan(fn=inner_fct,
outputs_info=[None]*2,
sequences=a,
non_sequences=b)
# Compile the function twice, once with the optimization and once
# without
f_opt = theano.function([a, b], outputs)
default_mode = theano.compile.get_default_mode()
default_mode.excluding("scanOp_pushout_output")
f_no_opt = theano.function([a, b], outputs, mode=default_mode)
# Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should
# not be the result of a Dot
scan_node = [node for node in f_opt.maker.fgraph.toposort()
if isinstance(node.op, Scan)][0]
assert len(scan_node.op.outputs) == 1
assert not isinstance(scan_node.op.outputs[0], T.Dot)
# Ensure that the function compiled with the optimization produces
# the same results as the function compiled without
a_value = numpy.random.random((3, 4))
b_value = numpy.random.random((4, 5))
output_opt = f_opt(a_value, b_value)
output_no_opt = f_no_opt(a_value, b_value)
utt.assert_allclose(output_opt[0], output_no_opt[0])
utt.assert_allclose(output_opt[1], output_no_opt[1])
def test_dot_sitsot_output(self):
"""
Test the case where the vector input to the dot is not already a
non-nitsot (in this case a sitsot) output of the inner function.
"""
a = T.matrix()
b = T.matrix()
def inner_fct(seq1, previous_output1, nonseq1):
output1 = previous_output1 + seq1
output2 = T.dot(output1, nonseq1)
return output1, output2
outputs, updates = theano.scan(fn=inner_fct,
outputs_info=[a[0], None],
sequences=a,
non_sequences=b)
# Compile the function twice, once with the optimization and once
# without
f_opt = theano.function([a, b], outputs)
default_mode = theano.compile.get_default_mode()
default_mode.excluding("scanOp_pushout_output")
f_no_opt = theano.function([a, b], outputs, mode=default_mode)
# Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should
# not be the result of a Dot
scan_node = [node for node in f_opt.maker.fgraph.toposort()
if isinstance(node.op, Scan)][0]
assert len(scan_node.op.outputs) == 2
assert not isinstance(scan_node.op.outputs[0], T.Dot)
# Ensure that the function compiled with the optimization produces
# the same results as the function compiled without
a_value = numpy.random.random((3, 4))
b_value = numpy.random.random((4, 5))
output_opt = f_opt(a_value, b_value)
output_no_opt = f_no_opt(a_value, b_value)
utt.assert_allclose(output_opt[0], output_no_opt[0])
utt.assert_allclose(output_opt[1], output_no_opt[1])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论