提交 853bbb10 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Add unit tests for pushing out a sum-of-dot from Scan

上级 6f5a06d8
...@@ -272,3 +272,142 @@ class TestPushOutScanOutputDot(object): ...@@ -272,3 +272,142 @@ class TestPushOutScanOutputDot(object):
utt.assert_allclose(output_opt[0], output_no_opt[0]) utt.assert_allclose(output_opt[0], output_no_opt[0])
utt.assert_allclose(output_opt[1], output_no_opt[1]) utt.assert_allclose(output_opt[1], output_no_opt[1])
class TestPushOutSumOfDot():
"""
Test case for the PushOutScanOutput optimizer in the case where the scan
is used to compute the sum over the dot products between the corresponding
elements of two list of matrices.
"""
def test_machine_translation(self):
"""
This test case comes from https://github.com/rizar/scan-grad-speed and
is an example of actual computation done with scan in the context of
machine translation
"""
# Parameters from an actual machine tranlation run
batch_size = 80
seq_len = 50
n_words = 80 * 50
dim = 1000
# Weight matrices
U = theano.shared(numpy.random.normal(size=(dim, dim), scale=0.0001).astype("float32"))
U.name = 'U'
V = theano.shared(U.get_value())
V.name = 'V'
W = theano.shared(U.get_value())
W.name = 'W'
# Variables and their values
x = T.ftensor3('x')
x_value = numpy.random.normal(size=(seq_len, batch_size, dim), scale=0.0001).astype("float32")
ri = T.ftensor3('ri')
ri_value = x_value
zi = T.ftensor3('zi')
zi_value = x_value
init = T.alloc(numpy.float32(0), batch_size, dim)
def rnn_step1(
# sequences
x, ri, zi,
# outputs_info
h):
pre_r = ri + h.dot(U)
pre_z = zi + h.dot(V)
r = T.nnet.sigmoid(pre_r)
z = T.nnet.sigmoid(pre_z)
after_r = r * h
pre_h = x + after_r.dot(W)
new_h = T.tanh(pre_h)
res_h = z * new_h + (1 - z) * h
return res_h
# Compile the function twice, once with the optimization and once
# without
h, _ = theano.scan(rnn_step1, sequences=[x, ri, zi], n_steps=seq_len,
outputs_info=init, name='fpass1')
cost = h[-1].sum()
grad1 = T.grad(cost, [U, V, W])
f_opt = theano.function(inputs=[x, ri, zi], outputs=grad1)
default_mode = theano.compile.get_default_mode()
new_mode = default_mode.excluding("scanOp_pushout_output")
h, _ = theano.scan(rnn_step1, sequences=[x, ri, zi], n_steps=seq_len,
outputs_info=init, name='fpass1', mode=new_mode)
cost = h[-1].sum()
grad1 = T.grad(cost, [U, V, W])
f_no_opt = theano.function(inputs=[x, ri, zi], outputs=grad1, mode=new_mode)
# Validate that the optimization has been applied
scan_node_grad = [node for node in f_opt.maker.fgraph.toposort()
if isinstance(node.op, Scan)][1]
for output in scan_node_grad.op.outputs:
assert not (isinstance(output.owner.op, T.elemwise.Elemwise) and
any([isinstance(i, T.Dot) for i
in output.owner.inputs]))
# Compare the outputs of the two functions on the same input data.
f_opt_output = f_opt(x_value, ri_value, zi_value)
f_no_opt_output = f_no_opt(x_value, ri_value, zi_value)
utt.assert_allclose(f_opt_output, f_no_opt_output)
def test_non_zero_init(self):
"""
Test the case where the initial value for the nitsot output is
non-zero
"""
input1 = T.dtensor3()
input2 = T.dtensor3()
input3 = T.dtensor3()
W = theano.shared(numpy.random.normal(size=(4, 5)))
U = theano.shared(numpy.random.normal(size=(6, 7)))
def inner_fct(seq1, seq2, seq3, previous_output):
temp1 = T.dot(seq1, W) + seq3
temp2 = T.dot(seq2, U)
dot_output = T.dot(temp1, temp2)
return previous_output + dot_output
init = T.as_tensor_variable(numpy.random.normal(size=(3,7)))
# Compile the function twice, once with the optimization and once
# without
h, _ = theano.scan(inner_fct,
sequences=[input1, input2, input3],
outputs_info=init)
output = h[-1]
f_opt = theano.function([input1, input2, input3], output)
default_mode = theano.compile.get_default_mode()
new_mode = default_mode.excluding("scanOp_pushout_output")
h, _ = theano.scan(inner_fct,
sequences=[input1, input2, input3],
outputs_info=init,
mode=new_mode)
output = h[-1]
f_no_opt = theano.function([input1, input2, input3], output,
mode=new_mode)
# Ensure that the optimization has been applied for f_opt
# TODO
# Compare the outputs of the 2 functions
input1_value = numpy.random.random((2, 3, 4))
input2_value = numpy.random.random((2, 5, 6))
input3_value = numpy.random.random((2, 3, 5))
output_opt = f_opt(input1_value, input2_value, input3_value)
output_no_opt = f_no_opt(input1_value, input2_value, input3_value)
utt.assert_allclose(output_opt, output_no_opt)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论