提交 0d895a46 authored 作者: Frederic's avatar Frederic

do not try cudnn opt when it isn't available

上级 c0217dd1
......@@ -1545,42 +1545,42 @@ if True:
@register_opt('cudnn')
@alpha_merge(GpuDnnConv, alpha_in=4, nd=4)
def local_dnn_conv_alpha_merge(node, *inputs):
if version() == -1:
if not dnn_available() or version() == -1:
return None
return [GpuDnnConv(workmem=node.op.workmem)(*inputs)]
@register_opt('cudnn')
@alpha_merge(GpuDnnConvGradW, alpha_in=4, nd=4)
def local_dnn_convw_alpha_merge(node, *inputs):
if version() == -1:
if not dnn_available() or version() == -1:
return None
return [GpuDnnConvGradW()(*inputs)]
@register_opt('cudnn')
@alpha_merge(GpuDnnConvGradI, alpha_in=4, nd=4)
def local_dnn_convi_alpha_merge(node, *inputs):
if version() == -1:
if not dnn_available() or version() == -1:
return None
return [GpuDnnConvGradI()(*inputs)]
@register_opt('cudnn')
@output_merge(GpuDnnConv, alpha_in=4, out_in=2, nd=4)
def local_dnn_conv_output_merge(node, *inputs):
if version() == -1:
if not dnn_available() or version() == -1:
return None
return [GpuDnnConv(workmem=node.op.workmem)(*inputs)]
@register_opt('cudnn')
@output_merge(GpuDnnConvGradW, alpha_in=4, out_in=2, nd=4)
def local_dnn_convw_output_merge(node, *inputs):
if version() == -1:
if not dnn_available() or version() == -1:
return None
return [GpuDnnConvGradW()(*inputs)]
@register_opt('cudnn')
@output_merge(GpuDnnConvGradI, alpha_in=4, out_in=2, nd=4)
def local_dnn_convi_output_merge(node, *inputs):
if version() == -1:
if not dnn_available() or version() == -1:
return None
return [GpuDnnConvGradI()(*inputs)]
......
......@@ -455,7 +455,10 @@ class TestDnnInferShapes(utt.InferShapeTester):
dnn.GpuDnnPoolGrad
)
def test_dnn_conv_merge():
if not cuda.dnn.dnn_available() or cuda.dnn.version() == -1:
raise SkipTest(cuda.dnn.dnn_available.msg)
img = T.ftensor4()
kern = T.ftensor4()
out = T.ftensor4()
......@@ -516,7 +519,7 @@ def test_dnn_conv_merge():
def test_dnn_conv_grad():
if dnn.version() == -1:
if not cuda.dnn.dnn_available() or dnn.version() == -1:
raise SkipTest('alpha != 1.0 not supported in cudnn v1')
b = 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论