Commit ea54556b, authored by abergeron

Merge pull request #2248 from carriepl/scan_push_out_dot

Scan opt to push out Dot products outside of Scan's inner graph
...@@ -2548,7 +2548,7 @@ class T_Scan(unittest.TestCase): ...@@ -2548,7 +2548,7 @@ class T_Scan(unittest.TestCase):
x = theano.tensor.fmatrix('x') x = theano.tensor.fmatrix('x')
mem_val = numpy.zeros((2,), dtype='float32') mem_val = numpy.zeros((2,), dtype='float32')
memory = theano.shared(mem_val.copy()) memory = theano.shared(mem_val)
W = theano.shared(numpy.random.random((5, 2)).astype('float32')) W = theano.shared(numpy.random.random((5, 2)).astype('float32'))
def f(inp, mem): def f(inp, mem):
...@@ -2566,7 +2566,7 @@ class T_Scan(unittest.TestCase): ...@@ -2566,7 +2566,7 @@ class T_Scan(unittest.TestCase):
x_val = numpy.random.random((4, 3)).astype('float32') x_val = numpy.random.random((4, 3)).astype('float32')
f_vals = f(x_val) f_vals = f(x_val)
memory.set_value(mem_val.copy()) memory.set_value(mem_val)
f2_vals = f2(x_val) f2_vals = f2(x_val)
utt.assert_allclose(f_vals, f2_vals) utt.assert_allclose(f_vals, f2_vals)
......
...@@ -4,6 +4,7 @@ import unittest ...@@ -4,6 +4,7 @@ import unittest
import theano import theano
from theano import config from theano import config
from theano import tensor as T from theano import tensor as T
from theano.scan_module.scan_op import Scan
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
mode = theano.compile.mode.get_mode(config.mode) mode = theano.compile.mode.get_mode(config.mode)
...@@ -134,3 +135,140 @@ class GaussNewtonMatrix(object): ...@@ -134,3 +135,140 @@ class GaussNewtonMatrix(object):
# apply Tikhonov damping # apply Tikhonov damping
JHJv = [JHJvi + damp * vi for JHJvi, vi in zip(JHJv, v)] JHJv = [JHJvi + damp * vi for JHJvi, vi in zip(JHJv, v)]
return JHJv return JHJv
class TestPushOutScanOutputDot(object):
    """
    Test class for the PushOutScanOutput optimizer in the case where the inner
    function of a scan op has an output which is the result of a Dot product
    on a non-sequence matrix input to scan and a vector that is the result of
    computation in the inner function.
    """

    # Tag used to disable the optimization under test when building the
    # reference (non-optimized) function.
    _OPT_TAG = "scanOp_pushout_output"

    def _compile_with_and_without_opt(self, inputs, outputs):
        """
        Compile `outputs` twice: once with the default mode and once with
        the optimization under test excluded.

        Returns the pair `(f_opt, f_no_opt)`.
        """
        f_opt = theano.function(inputs, outputs)

        # `excluding` does not modify the mode in place, it returns a new
        # mode; the returned value is the one that must be used.
        no_opt_mode = theano.compile.get_default_mode().excluding(
            self._OPT_TAG)
        f_no_opt = theano.function(inputs, outputs, mode=no_opt_mode)

        return f_opt, f_no_opt

    def _check_scan_inner_outputs(self, f_opt, expected_nb_outputs):
        """
        Ensure that the first Scan node of `f_opt` has the expected number
        of inner outputs and that none of them is produced by a Dot.
        """
        scan_node = [node for node in f_opt.maker.fgraph.toposort()
                     if isinstance(node.op, Scan)][0]
        inner_outputs = scan_node.op.outputs
        assert len(inner_outputs) == expected_nb_outputs

        # The inner outputs are variables; the op that produced each of them
        # (if any) is found through its `owner` apply node.
        for inner_output in inner_outputs:
            owner = inner_output.owner
            assert owner is None or not isinstance(owner.op, T.Dot)

    def test_dot_not_output(self):
        """
        Test the case where the vector input to the dot is not already an
        output of the inner function.
        """
        v = T.vector()
        m = T.matrix()
        output = T.dot(v, m)

        # Compile the function twice, once with the optimization and once
        # without
        f_opt, f_no_opt = self._compile_with_and_without_opt(
            [v, m], T.jacobian(output, v))

        # Ensure that the optimization was performed correctly in f_opt
        # The inner function of scan should have only one output and it should
        # not be the result of a Dot
        self._check_scan_inner_outputs(f_opt, 1)

        # Ensure that the function compiled with the optimization produces
        # the same results as the function compiled without
        # (the values are cast to floatX to match the symbolic inputs)
        v_value = numpy.random.random((4)).astype(config.floatX)
        m_value = numpy.random.random((4, 5)).astype(config.floatX)

        output_opt = f_opt(v_value, m_value)
        output_no_opt = f_no_opt(v_value, m_value)

        utt.assert_allclose(output_opt, output_no_opt)

    def test_dot_nitsot_output(self):
        """
        Test the case where the vector input to the dot is already a nitsot
        output of the inner function.
        """
        a = T.matrix()
        b = T.matrix()

        def inner_fct(vect, mat):
            vect_squared = vect ** 2
            return T.dot(vect_squared, mat), vect_squared

        outputs, updates = theano.scan(fn=inner_fct,
                                       outputs_info=[None] * 2,
                                       sequences=a,
                                       non_sequences=b)

        # Compile the function twice, once with the optimization and once
        # without
        f_opt, f_no_opt = self._compile_with_and_without_opt([a, b], outputs)

        # Ensure that the optimization was performed correctly in f_opt
        # The inner function of scan should have only one output and it should
        # not be the result of a Dot
        self._check_scan_inner_outputs(f_opt, 1)

        # Ensure that the function compiled with the optimization produces
        # the same results as the function compiled without
        # (the values are cast to floatX to match the symbolic inputs)
        a_value = numpy.random.random((3, 4)).astype(config.floatX)
        b_value = numpy.random.random((4, 5)).astype(config.floatX)

        output_opt = f_opt(a_value, b_value)
        output_no_opt = f_no_opt(a_value, b_value)

        utt.assert_allclose(output_opt[0], output_no_opt[0])
        utt.assert_allclose(output_opt[1], output_no_opt[1])

    def test_dot_sitsot_output(self):
        """
        Test the case where the vector input to the dot is already a
        non-nitsot (in this case a sitsot) output of the inner function.
        """
        a = T.matrix()
        b = T.matrix()

        def inner_fct(seq1, previous_output1, nonseq1):
            output1 = previous_output1 + seq1
            output2 = T.dot(output1, nonseq1)
            return output1, output2

        outputs, updates = theano.scan(fn=inner_fct,
                                       outputs_info=[a[0], None],
                                       sequences=a,
                                       non_sequences=b)

        # Compile the function twice, once with the optimization and once
        # without
        f_opt, f_no_opt = self._compile_with_and_without_opt([a, b], outputs)

        # Ensure that the optimization was performed correctly in f_opt
        # The inner function of scan should have two outputs and neither of
        # them should be the result of a Dot
        self._check_scan_inner_outputs(f_opt, 2)

        # Ensure that the function compiled with the optimization produces
        # the same results as the function compiled without
        # (the values are cast to floatX to match the symbolic inputs)
        a_value = numpy.random.random((3, 4)).astype(config.floatX)
        b_value = numpy.random.random((4, 5)).astype(config.floatX)

        output_opt = f_opt(a_value, b_value)
        output_no_opt = f_no_opt(a_value, b_value)

        utt.assert_allclose(output_opt[0], output_no_opt[0])
        utt.assert_allclose(output_opt[1], output_no_opt[1])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论