提交 8e371199 authored 作者: Frederic Bastien's avatar Frederic Bastien

Move V100/cuDNN check at the same place as other cudnn check.

上级 b4f39855
...@@ -99,17 +99,6 @@ def init_dev(dev, name=None, preallocate=None): ...@@ -99,17 +99,6 @@ def init_dev(dev, name=None, preallocate=None):
MB = (1024 * 1024) MB = (1024 * 1024)
if dev.startswith('cuda'): if dev.startswith('cuda'):
avail = dnn.dnn_available(name) avail = dnn.dnn_available(name)
# On V100, cuDNN lower then 7002 don't raise error but
# takes hours to load! So raise a good user error.
if avail and dnn.version() < 7002:
bin_id = context.bin_id
assert bin_id.startswith("compute_"), context_bin_id
if int(bin_id[8:]) >= 70:
raise RuntimeError(
"You have cuDNN version %d, while the GPU is a Volta"
" genaration or more recent. This cause extreme"
" slowness, so we disable it."
" Use cuDNN 7.0.2 or higher." % (dnn.version()))
# If we try to enable cudnn and there isn't enough GPU # If we try to enable cudnn and there isn't enough GPU
# memory, there will be an unclear error message. So do # memory, there will be an unclear error message. So do
# not even try a clear error. # not even try a clear error.
......
...@@ -222,10 +222,16 @@ def dnn_available(context_name): ...@@ -222,10 +222,16 @@ def dnn_available(context_name):
# This is a hack because bin_id is in the from of # This is a hack because bin_id is in the from of
# "<something>_<major><minor>" for cuda devices. # "<something>_<major><minor>" for cuda devices.
if ctx.bin_id[-2:] < b'30': if int(ctx.bin_id[-2:]) < 30:
dnn_available.msg = "Device not supported" dnn_available.msg = "Device not supported"
return False return False
# On V100, cuDNN lower then 7002 don't raise error but
# takes hours to load or execute! So raise a good user error.
if version() < 7002:
if int(ctx.bin_id[-2:]) >= 70:
dnn_available.msg = "Use cuDNN 7.0.2 or higher for Volta."
return False
return True return True
dnn_available.msg = None dnn_available.msg = None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论