Commit 0d895a46 authored by Frederic

do not try cudnn opt when it isn't available

Parent c0217dd1
...@@ -1545,42 +1545,42 @@ if True: ...@@ -1545,42 +1545,42 @@ if True:
@register_opt('cudnn')
@alpha_merge(GpuDnnConv, alpha_in=4, nd=4)
def local_dnn_conv_alpha_merge(node, *inputs):
    """Fold a scalar alpha multiplier into a GpuDnnConv node.

    Skipped when cuDNN is unavailable, or when version() reports -1
    (cuDNN v1), which cannot handle alpha != 1.0.
    """
    if dnn_available() and version() != -1:
        return [GpuDnnConv(workmem=node.op.workmem)(*inputs)]
    return None
@register_opt('cudnn')
@alpha_merge(GpuDnnConvGradW, alpha_in=4, nd=4)
def local_dnn_convw_alpha_merge(node, *inputs):
    """Fold a scalar alpha multiplier into a GpuDnnConvGradW node.

    Only applies when cuDNN is present and newer than v1 (version()
    of -1 means v1, which does not support alpha != 1.0).
    """
    unsupported = not dnn_available() or version() == -1
    if unsupported:
        return None
    return [GpuDnnConvGradW()(*inputs)]
@register_opt('cudnn')
@alpha_merge(GpuDnnConvGradI, alpha_in=4, nd=4)
def local_dnn_convi_alpha_merge(node, *inputs):
    """Fold a scalar alpha multiplier into a GpuDnnConvGradI node.

    Bails out when cuDNN is missing or is v1 (version() == -1),
    since v1 cannot apply alpha != 1.0.
    """
    if not dnn_available():
        return None
    if version() == -1:
        return None
    return [GpuDnnConvGradI()(*inputs)]
@register_opt('cudnn')
@output_merge(GpuDnnConv, alpha_in=4, out_in=2, nd=4)
def local_dnn_conv_output_merge(node, *inputs):
    """Merge an output accumulator into a GpuDnnConv node.

    Requires a working cuDNN newer than v1 (version() == -1),
    because v1 lacks support for the merged form.
    """
    cudnn_ok = dnn_available() and version() != -1
    if not cudnn_ok:
        return None
    return [GpuDnnConv(workmem=node.op.workmem)(*inputs)]
@register_opt('cudnn')
@output_merge(GpuDnnConvGradW, alpha_in=4, out_in=2, nd=4)
def local_dnn_convw_output_merge(node, *inputs):
    """Merge an output accumulator into a GpuDnnConvGradW node.

    Does nothing unless cuDNN is available and newer than v1
    (version() reports -1 for v1).
    """
    if dnn_available() and version() != -1:
        return [GpuDnnConvGradW()(*inputs)]
    return None
@register_opt('cudnn')
@output_merge(GpuDnnConvGradI, alpha_in=4, out_in=2, nd=4)
def local_dnn_convi_output_merge(node, *inputs):
    """Merge an output accumulator into a GpuDnnConvGradI node.

    Skipped entirely when cuDNN cannot be used or is v1
    (version() == -1), which lacks the required support.
    """
    usable = dnn_available() and version() != -1
    return [GpuDnnConvGradI()(*inputs)] if usable else None
......
...@@ -455,7 +455,10 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -455,7 +455,10 @@ class TestDnnInferShapes(utt.InferShapeTester):
dnn.GpuDnnPoolGrad dnn.GpuDnnPoolGrad
) )
def test_dnn_conv_merge(): def test_dnn_conv_merge():
if not cuda.dnn.dnn_available() or cuda.dnn.version() == -1:
raise SkipTest(cuda.dnn.dnn_available.msg)
img = T.ftensor4() img = T.ftensor4()
kern = T.ftensor4() kern = T.ftensor4()
out = T.ftensor4() out = T.ftensor4()
...@@ -516,7 +519,7 @@ def test_dnn_conv_merge(): ...@@ -516,7 +519,7 @@ def test_dnn_conv_merge():
def test_dnn_conv_grad(): def test_dnn_conv_grad():
if dnn.version() == -1: if not cuda.dnn.dnn_available() or dnn.version() == -1:
raise SkipTest('alpha != 1.0 not supported in cudnn v1') raise SkipTest('alpha != 1.0 not supported in cudnn v1')
b = 1 b = 1
......
Markdown format
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment