提交 def70fb8 authored 作者: Frederic's avatar Frederic

Fix tests. Add back old opt tag and test the new opt tag and the old one.

上级 4bc20eb1
......@@ -1337,11 +1337,13 @@ conv_groupopt.register('conv_fft_full', local_conv_fft_full, 10,
# It can be disabled by excluding 'conv_dnn' or 'cudnn'.
from . import dnn
if dnn.dnn_available():
conv_groupopt.register('conv_dnn', dnn.local_conv_dnn, 20,
conv_groupopt.register('local_conv_dnn', dnn.local_conv_dnn, 20,
'conv_dnn',
'fast_compile', 'fast_run', 'cudnn')
# The GEMM-based convolution comes last to catch all remaining cases.
# It can be disabled by excluding 'conv_gemm'.
conv_groupopt.register('conv_gemm', local_conv_gemm, 30,
conv_groupopt.register('local_conv_gemm', local_conv_gemm, 30,
'conv_gemm',
'fast_compile', 'fast_run')
......
......@@ -616,7 +616,13 @@ def test_default_conv():
assert any([isinstance(a.op, cuda.blas.GpuCorrMM)
for a in f.maker.fgraph.apply_nodes])
mode = theano_mode.excluding('local_gpu_conv', 'local_conv_gemm')
mode = theano_mode.excluding('local_conv_dnn', 'local_conv_gemm')
f = theano.function([img, fil], c, mode=mode)
assert any([isinstance(a.op, cuda.blas.GpuConv)
for a in f.maker.fgraph.apply_nodes])
mode = theano_mode.excluding('conv_dnn', 'conv_gemm')
f = theano.function([img, fil], c, mode=mode)
assert any([isinstance(a.op, cuda.blas.GpuConv)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论