提交 00932b3c authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Update hessian_bug_grad_grad test to run faster

上级 47021839
...@@ -3775,26 +3775,26 @@ class T_Scan(unittest.TestCase): ...@@ -3775,26 +3775,26 @@ class T_Scan(unittest.TestCase):
inp = scan_node.op.outer_non_seqs(scan_node) inp = scan_node.op.outer_non_seqs(scan_node)
assert len(inp) == 1 assert len(inp) == 1
@attr('slow')
def test_hessian_bug_grad_grad_two_scans(self): def test_hessian_bug_grad_grad_two_scans(self):
#Bug reported by Bitton Tenessi #Bug reported by Bitton Tenessi
# NOTE : The test to reproduce the bug reported by Bitton Tenessi
# was modified from its original version to be faster to run.
W_flat = tensor.fvector(name='W') W = tensor.fvector(name='W')
W_flat.tag.test_value = numpy.ones((8,), dtype=numpy.float32) n_steps = tensor.iscalar(name='Nb_steps')
W = W_flat.reshape((2, 2, 2))
def loss_outer(i_outer, sum_outer, W): def loss_outer(sum_outer, W):
def loss_inner(i_inner, sum_inner, W): def loss_inner(sum_inner, W):
return sum_inner + (W**2).sum().sum().sum() return sum_inner + (W**2).sum()
result_inner, _ = theano.scan( result_inner, _ = theano.scan(
fn=loss_inner, fn=loss_inner,
outputs_info=tensor.as_tensor_variable( outputs_info=tensor.as_tensor_variable(
numpy.asarray(0, dtype=numpy.float32)), numpy.asarray(0, dtype=numpy.float32)),
sequences=tensor.arange(1, dtype='int32'),
non_sequences=[W], non_sequences=[W],
n_steps=1,
) )
return sum_outer + result_inner[-1] return sum_outer + result_inner[-1]
...@@ -3802,15 +3802,15 @@ class T_Scan(unittest.TestCase): ...@@ -3802,15 +3802,15 @@ class T_Scan(unittest.TestCase):
fn=loss_outer, fn=loss_outer,
outputs_info=tensor.as_tensor_variable( outputs_info=tensor.as_tensor_variable(
numpy.asarray(0, dtype=numpy.float32)), numpy.asarray(0, dtype=numpy.float32)),
sequences=tensor.arange(1, dtype='int32'),
non_sequences=[W], non_sequences=[W],
n_steps=n_steps,
) )
cost = result_outer[-1] cost = result_outer[-1]
H = theano.gradient.hessian(cost, W_flat) H = theano.gradient.hessian(cost, W)
print >> sys.stderr, "." print >> sys.stderr, "."
f = theano.function([W_flat], H) f = theano.function([W, n_steps], H)
f(numpy.ones((8,), dtype='float32')) f(numpy.ones((8,), dtype='float32'), 1)
def test_speed(): def test_speed():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论