提交 06cc52d7 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2323 from carriepl/scan_push_out_sum_of_dot

Scan push out sum of dot
...@@ -1427,6 +1427,11 @@ class Scan(PureOp): ...@@ -1427,6 +1427,11 @@ class Scan(PureOp):
else: else:
grad_steps = inputs[0] grad_steps = inputs[0]
# Restrict the number of grad steps according to
# self.truncate_gradient
if self.truncate_gradient != -1:
grad_steps = tensor.minimum(grad_steps, self.truncate_gradient)
rval = scan_utils.reconstruct_graph(self.inputs, rval = scan_utils.reconstruct_graph(self.inputs,
self.outputs) self.outputs)
self_inputs = rval[0] self_inputs = rval[0]
...@@ -1652,6 +1657,10 @@ class Scan(PureOp): ...@@ -1652,6 +1657,10 @@ class Scan(PureOp):
outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)] outer_inp_seqs += [x[::-1][:-1] for x in self.outer_sitsot_outs(outs)]
outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)] outer_inp_seqs += [x[::-1] for x in self.outer_nitsot_outs(outs)]
# Restrict the length of the outer sequences to the number of grad
# steps
outer_inp_seqs = [seq[:grad_steps] for seq in outer_inp_seqs]
inner_inp_seqs = self.inner_seqs(self_inputs) inner_inp_seqs = self.inner_seqs(self_inputs)
inner_inp_seqs += self.inner_mitmot(self_inputs) inner_inp_seqs += self.inner_mitmot(self_inputs)
inner_inp_seqs += self.inner_mitsot(self_inputs) inner_inp_seqs += self.inner_mitsot(self_inputs)
...@@ -1820,9 +1829,6 @@ class Scan(PureOp): ...@@ -1820,9 +1829,6 @@ class Scan(PureOp):
ins_pos += 1 ins_pos += 1
n_mitmot_inps += 2 n_mitmot_inps += 2
if self.truncate_gradient != -1:
grad_steps = tensor.minimum(grad_steps, self.truncate_gradient)
n_nit_sot = self.n_seqs n_nit_sot = self.n_seqs
inner_out_nitsot = dC_dinps_t[:self.n_seqs] inner_out_nitsot = dC_dinps_t[:self.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:] inner_out_sitsot = dC_dinps_t[ins_pos:]
......
...@@ -157,12 +157,11 @@ class TestPushOutScanOutputDot(object): ...@@ -157,12 +157,11 @@ class TestPushOutScanOutputDot(object):
# Compile the function twice, once with the optimization and once # Compile the function twice, once with the optimization and once
# without # without
f_opt = theano.function([v, m], T.jacobian(output, v)) opt_mode = mode.including("scan")
f_opt = theano.function([v, m], T.jacobian(output, v), mode=opt_mode)
default_mode = theano.compile.get_default_mode() no_opt_mode = mode.excluding("scanOp_pushout_output")
default_mode.excluding("scanOp_pushout_output") f_no_opt = theano.function([v, m], T.jacobian(output, v), mode=no_opt_mode)
f_no_opt = theano.function([v, m], T.jacobian(output, v),
mode=default_mode)
# Ensure that the optimization was performed correctly in f_opt # Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should # The inner function of scan should have only one output and it should
...@@ -248,11 +247,11 @@ class TestPushOutScanOutputDot(object): ...@@ -248,11 +247,11 @@ class TestPushOutScanOutputDot(object):
# Compile the function twice, once with the optimization and once # Compile the function twice, once with the optimization and once
# without # without
f_opt = theano.function([a, b], outputs) opt_mode = mode.including("scan")
f_opt = theano.function([a, b], outputs, mode=opt_mode)
default_mode = theano.compile.get_default_mode() no_opt_mode = mode.excluding("scanOp_pushout_output")
default_mode.excluding("scanOp_pushout_output") f_no_opt = theano.function([a, b], outputs, mode=no_opt_mode)
f_no_opt = theano.function([a, b], outputs, mode=default_mode)
# Ensure that the optimization was performed correctly in f_opt # Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should # The inner function of scan should have only one output and it should
...@@ -272,3 +271,150 @@ class TestPushOutScanOutputDot(object): ...@@ -272,3 +271,150 @@ class TestPushOutScanOutputDot(object):
utt.assert_allclose(output_opt[0], output_no_opt[0]) utt.assert_allclose(output_opt[0], output_no_opt[0])
utt.assert_allclose(output_opt[1], output_no_opt[1]) utt.assert_allclose(output_opt[1], output_no_opt[1])
class TestPushOutSumOfDot(object):
    """
    Test case for the PushOutScanOutput optimizer in the case where the scan
    is used to compute the sum over the dot products between the corresponding
    elements of two list of matrices.

    Each test builds the same graph twice, compiles it once with the "scan"
    optimizations included and once with "scanOp_pushout_output" excluded,
    and checks that both compiled functions return the same values.
    """

    def test_machine_translation(self):
        """
        This test case comes from https://github.com/rizar/scan-grad-speed and
        is an example of actual computation done with scan in the context of
        machine translation

        'dim' has been reduced from 1000 to 5 to make the test run faster
        """
        # Parameters from an actual machine translation run
        # (the original run had n_words = batch_size * seq_len = 80 * 50).
        batch_size = 80
        seq_len = 50
        dim = 5

        # Weight matrices. V and W start from the same values as U.
        U = theano.shared(numpy.random.normal(size=(dim, dim),
                                              scale=0.0001).astype(config.floatX))
        U.name = 'U'
        V = theano.shared(U.get_value())
        V.name = 'V'
        W = theano.shared(U.get_value())
        W.name = 'W'

        # Variables and their values. The three sequences are fed the same
        # numerical data.
        x = T.tensor3('x')
        x_value = numpy.random.normal(size=(seq_len, batch_size, dim),
                                      scale=0.0001).astype(config.floatX)

        ri = T.tensor3('ri')
        ri_value = x_value

        zi = T.tensor3('zi')
        zi_value = x_value

        init = T.alloc(numpy.cast[config.floatX](0), batch_size, dim)

        def rnn_step1(
                # sequences
                x, ri, zi,
                # outputs_info
                h):
            pre_r = ri + h.dot(U)
            pre_z = zi + h.dot(V)
            r = T.nnet.sigmoid(pre_r)
            z = T.nnet.sigmoid(pre_z)

            after_r = r * h
            pre_h = x + after_r.dot(W)
            new_h = T.tanh(pre_h)

            res_h = z * new_h + (1 - z) * h
            return res_h

        # Compile the function twice, once with the optimization and once
        # without
        opt_mode = mode.including("scan")
        h, _ = theano.scan(rnn_step1, sequences=[x, ri, zi], n_steps=seq_len,
                           outputs_info=init, name='fpass1', mode=opt_mode)
        cost = h[-1].sum()
        grad1 = T.grad(cost, [U, V, W])
        f_opt = theano.function(inputs=[x, ri, zi], outputs=grad1,
                                mode=opt_mode)

        no_opt_mode = mode.excluding("scanOp_pushout_output")
        h, _ = theano.scan(rnn_step1, sequences=[x, ri, zi], n_steps=seq_len,
                           outputs_info=init, name='fpass1', mode=no_opt_mode)
        cost = h[-1].sum()
        grad1 = T.grad(cost, [U, V, W])
        f_no_opt = theano.function(inputs=[x, ri, zi], outputs=grad1,
                                   mode=no_opt_mode)

        # Validate that the optimization has been applied.
        # The second Scan node of the compiled graph is inspected;
        # presumably that is the scan computing the gradient -- TODO confirm
        # that the toposort order guarantees this.
        scan_node_grad = [node for node in f_opt.maker.fgraph.toposort()
                          if isinstance(node.op, Scan)][1]

        # NOTE(review): `i` iterates over the inputs of an apply node, so
        # `isinstance(i, T.Dot)` looks like it can never be true if T.Dot is
        # an Op class; the intended check is probably on `i.owner.op`. The
        # condition is left unchanged here because tightening it changes
        # what the test accepts -- TODO confirm and fix separately.
        for output in scan_node_grad.op.outputs:
            assert not (isinstance(output.owner.op, T.elemwise.Elemwise) and
                        any([isinstance(i, T.Dot) for i
                             in output.owner.inputs]))

        # Compare the outputs of the two functions on the same input data.
        f_opt_output = f_opt(x_value, ri_value, zi_value)
        f_no_opt_output = f_no_opt(x_value, ri_value, zi_value)
        utt.assert_allclose(f_opt_output, f_no_opt_output)

    def test_non_zero_init(self):
        """
        Test the case where the initial value for the nitsot output is
        non-zero
        """
        input1 = T.tensor3()
        input2 = T.tensor3()
        input3 = T.tensor3()

        # Cast the numerical values *before* wrapping them, so that W and U
        # are genuine shared variables of dtype floatX rather than symbolic
        # casts applied on top of float64 shared variables.
        W = theano.shared(
            numpy.random.normal(size=(4, 5)).astype(config.floatX))
        U = theano.shared(
            numpy.random.normal(size=(6, 7)).astype(config.floatX))

        def inner_fct(seq1, seq2, seq3, previous_output):
            temp1 = T.dot(seq1, W) + seq3
            temp2 = T.dot(seq2, U)
            dot_output = T.dot(temp1, temp2)
            return previous_output + dot_output

        # Non-zero initial state, in floatX like every other value used by
        # this test.
        init = T.as_tensor_variable(
            numpy.random.normal(size=(3, 7)).astype(config.floatX))

        # Compile the function twice, once with the optimization and once
        # without
        opt_mode = mode.including("scan")
        h, _ = theano.scan(inner_fct,
                           sequences=[input1, input2, input3],
                           outputs_info=init,
                           mode=opt_mode)
        output = h[-1]
        f_opt = theano.function([input1, input2, input3], output,
                                mode=opt_mode)

        no_opt_mode = mode.excluding("scanOp_pushout_output")
        h, _ = theano.scan(inner_fct,
                           sequences=[input1, input2, input3],
                           outputs_info=init,
                           mode=no_opt_mode)
        output = h[-1]
        f_no_opt = theano.function([input1, input2, input3], output,
                                   mode=no_opt_mode)

        # Ensure that the optimization has been applied for f_opt
        # TODO

        # Compare the outputs of the 2 functions
        input1_value = numpy.random.random((2, 3, 4)).astype(config.floatX)
        input2_value = numpy.random.random((2, 5, 6)).astype(config.floatX)
        input3_value = numpy.random.random((2, 3, 5)).astype(config.floatX)

        output_opt = f_opt(input1_value, input2_value, input3_value)
        output_no_opt = f_no_opt(input1_value, input2_value, input3_value)

        utt.assert_allclose(output_opt, output_no_opt)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论