提交 e4d68156 authored 作者: --global's avatar --global

Add test case

上级 2569572b
...@@ -702,6 +702,30 @@ class T_Scan(unittest.TestCase): ...@@ -702,6 +702,30 @@ class T_Scan(unittest.TestCase):
expected_result = [1, 2, 2] expected_result = [1, 2, 2]
assert(result == expected_result) assert(result == expected_result)
def test_grad_grad_mitsot_sitsot(self):
# Test for an index error when taking the second derivative
# through a Scan node with one sitsot and one mitsot.
def inner_fct(mitsot_m2, mitsot_m1, sitsot):
total = mitsot_m2 + mitsot_m1 + sitsot
output = total ** 2
return output, output
inputs = [tensor.matrix(), tensor.vector()]
outputs_info = [dict(initial=inputs[0], taps=[-2, -1]), inputs[1]]
scan_outputs, updates = theano.scan(fn=inner_fct,
outputs_info=outputs_info,
n_steps=5)
# Take the gradient of each output wrt its corresponding initial state
gradients = [theano.grad(scan_outputs[0].sum(), inputs[0]),
theano.grad(scan_outputs[1].sum(), inputs[1])]
# Take the gradient of the sum of gradients wrt the inputs
sum_of_grads = sum([g.sum() for g in gradients])
second_gradients = theano.grad(sum_of_grads, inputs[0])
def test_grad_two_scans(self): def test_grad_two_scans(self):
# data input & output # data input & output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论