提交 fdde035a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Update stuff for the final release of v3.

上级 c2287027
......@@ -82,18 +82,20 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
else:
# If we can compile, check that we can import and run.
v = version()
if v == -1:
if v < 2000:
dnn_available.avail = False
dnn_available.msg = (
"You have CuDNN v1 installed, upgrade to v2 or more recent.")
"You have an old release of CuDNN (or a release candidate) "
"that isn't supported. Please update to at least v2 final "
"version.")
raise RuntimeError(dnn_available.msg)
if v == 20:
if v >= 3000 and v < 3007:
dnn_available.avail = False
dnn_available.msg = (
"You have installed a release candidate of CuDNN v2."
" This isn't supported anymore."
" Update to CuDNN v2 final version.")
"You have installed a release candidate of CuDNN v3. This "
"isn't supported. Please update to v3 final version.")
raise RuntimeError(dnn_available.msg)
return dnn_available.avail
dnn_available.avail = None
......@@ -507,7 +509,7 @@ class GpuDnnConvGradW(DnnBase):
def __init__(self, inplace=False, algo=None):
DnnBase.__init__(self, ["dnn_conv_base.c", "dnn_gw.c"],
"APPLY_SPECIFIC(conv_gw)")
"APPLY_SPECIFIC(conv_gw)")
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [2]}
......@@ -787,7 +789,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
# Special case: We can be faster by using GpuDnnConvGradI to compute
# the full convolution as the backward pass of a valid convolution.
# We just need to set up a suitable 'fake' valid convolution.
img = gpu_contiguous(img) # cudnn v1 and v2 rc3 need contiguous data
img = gpu_contiguous(img) # cudnn v2 rc3 need contiguous data
kerns = gpu_contiguous(kerns.dimshuffle(1, 0, 2, 3))
conv_mode = 'cross' if conv_mode == 'conv' else 'conv'
shape2 = shape_i(img, 2, fgraph) + shape_i(kerns, 2, fgraph) - 1
......@@ -1429,8 +1431,7 @@ gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), 0, 'cudnn')
@register_opt('cudnn')
@op_lifter([SoftmaxGrad])
def local_softmax_dnn_grad(node):
if not dnn_available() or version() != 2000:
# softmaxgrad (n, c, 1, 1) broken in v3 rc1
if not dnn_available():
return
ins = []
for n in node.inputs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论