提交 242145f6 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix test.

上级 f722afba
...@@ -1482,17 +1482,13 @@ def test_dnn_rnn_gru(): ...@@ -1482,17 +1482,13 @@ def test_dnn_rnn_gru():
model.get_params()) model.get_params())
cudnn_fn, cudnn_grad_fn = funcs(rnnb.apply(params_cudnn, X, h0)[0], cudnn_fn, cudnn_grad_fn = funcs(rnnb.apply(params_cudnn, X, h0)[0],
[params_cudnn]) [params_cudnn])
# Make a copy of the params
params_cudnn2 = gpuarray_shared_constructor(params_cudnn.get_value())
y, hy = rnnb.apply(params_cudnn2, X, h0)
# Test with grad connected to both y and hy # Test with grad connected to both y and hy
y, hy = rnnb.apply(params_cudnn, X, h0)
cudnn2_fn, cudnn2_grad_fn = funcs((y + hy[-1]) / 2, cudnn2_fn, cudnn2_grad_fn = funcs((y + hy[-1]) / 2,
[params_cudnn]) [params_cudnn])
# Make a copy of the params
params_cudnn3 = gpuarray_shared_constructor(params_cudnn.get_value())
y, hy = rnnb.apply(params_cudnn3, X, h0)
# Test with grad connected to both y and hy # Test with grad connected to both y and hy
y, hy = rnnb.apply(params_cudnn, X, h0)
cudnn3_fn, cudnn3_grad_fn = funcs(hy[-1], cudnn3_fn, cudnn3_grad_fn = funcs(hy[-1],
[params_cudnn]) [params_cudnn])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论