提交 4895a28c authored 作者: Frederic Bastien's avatar Frederic Bastien

Remove useless code as the new back-end don't support older cudnn anywhere now.

上级 95d3d9cb
...@@ -1508,8 +1508,6 @@ def test_dnn_batchnorm_train_without_running_averages(): ...@@ -1508,8 +1508,6 @@ def test_dnn_batchnorm_train_without_running_averages():
# compile and run batch_normalization_train without running averages # compile and run batch_normalization_train without running averages
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
if dnn.version(raises=False) < 5000:
raise SkipTest("batch normalization requires cudnn v5+")
utt.seed_rng() utt.seed_rng()
x, scale, bias, dy = T.tensor4('x'), T.tensor4('scale'), T.tensor4('bias'), T.tensor4('dy') x, scale, bias, dy = T.tensor4('x'), T.tensor4('scale'), T.tensor4('bias'), T.tensor4('dy')
...@@ -1593,8 +1591,6 @@ def test_dnn_batchnorm_train_inplace(): ...@@ -1593,8 +1591,6 @@ def test_dnn_batchnorm_train_inplace():
# test inplace_running_mean and inplace_running_var # test inplace_running_mean and inplace_running_var
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
if dnn.version(raises=False) < 5000:
raise SkipTest("batch normalization requires cudnn v5+")
utt.seed_rng() utt.seed_rng()
x, scale, bias = T.tensor4('x'), T.tensor4('scale'), T.tensor4('bias') x, scale, bias = T.tensor4('x'), T.tensor4('scale'), T.tensor4('bias')
...@@ -1717,8 +1713,6 @@ def test_batchnorm_inference_inplace(): ...@@ -1717,8 +1713,6 @@ def test_batchnorm_inference_inplace():
# test inplace # test inplace
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
if dnn.version(raises=False) < 5000:
raise SkipTest("batch normalization requires cudnn v5+")
utt.seed_rng() utt.seed_rng()
x, scale, bias, mean, var = (T.tensor4(n) for n in ('x', 'scale', 'bias', 'mean', 'var')) x, scale, bias, mean, var = (T.tensor4(n) for n in ('x', 'scale', 'bias', 'mean', 'var'))
...@@ -1746,8 +1740,6 @@ def test_batchnorm_inference_inplace(): ...@@ -1746,8 +1740,6 @@ def test_batchnorm_inference_inplace():
def test_dnn_batchnorm_valid_and_invalid_axes(): def test_dnn_batchnorm_valid_and_invalid_axes():
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
if dnn.version(raises=False) < 5000:
raise SkipTest("batch normalization requires cudnn v5+")
for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix): for vartype in (T.tensor5, T.tensor4, T.tensor3, T.matrix):
x, scale, bias, mean, var, dy = (vartype(n) x, scale, bias, mean, var, dy = (vartype(n)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论