提交 2c13f91b authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix crash during opt when we were looking for the cudnn version.

上级 89be3bcf
...@@ -215,16 +215,22 @@ class DnnVersion(Op): ...@@ -215,16 +215,22 @@ class DnnVersion(Op):
return None return None
def version(): def version(raises=True):
""" """
Return the current cuDNN version we link with. Return the current cuDNN version we link with.
This also does a check that the header version matches the runtime version. This also does a check that the header version matches the runtime version.
:raises: If True, raise an exception if CuDNN is not present or badly installed.
Otherwise, return -1.
""" """
if not dnn_present(): if not dnn_present():
if raises:
raise Exception( raise Exception(
"We can't determine the cudnn version as it is not available", "We can't determine the cudnn version as it is not available",
dnn_available.msg) dnn_available.msg)
else:
return -1
if version.v is None: if version.v is None:
f = theano.function([], DnnVersion()(), f = theano.function([], DnnVersion()(),
...@@ -1204,7 +1210,7 @@ class GpuDnnSoftmaxBase(DnnBase): ...@@ -1204,7 +1210,7 @@ class GpuDnnSoftmaxBase(DnnBase):
DnnBase.__init__(self, [self.file], self.c_func) DnnBase.__init__(self, [self.file], self.c_func)
assert(algo in ('fast', 'accurate', 'log')) assert(algo in ('fast', 'accurate', 'log'))
if algo == 'log' and version() < 3000: if algo == 'log' and version(raises=False) < 3000:
raise RuntimeError("Need CuDNN v3 for log-softmax") raise RuntimeError("Need CuDNN v3 for log-softmax")
self.algo = algo self.algo = algo
...@@ -1485,15 +1491,15 @@ def local_softmax_dnn(node): ...@@ -1485,15 +1491,15 @@ def local_softmax_dnn(node):
@register_opt('cudnn')
@local_optimizer([GpuElemwise])
def local_log_softmax_dnn(node):
    """Rewrite log(GpuDnnSoftmax(x)) into a single GpuDnnSoftmax('log') op.

    Matching on GpuDnnSoftmax in the graph already guarantees cuDNN is
    available, so the version check is only done once a candidate pattern
    is found (avoids probing cuDNN on every node visited by the optimizer).

    :param node: the apply node being considered by the optimizer.
    :return: a one-element list with the replacement op's output, or
        None (implicitly) when the pattern does not match.
    """
    # This looks for GpuDnnSoftmax so we know that we have cudnn.
    if (isinstance(node.op, GpuElemwise) and
            isinstance(node.op.scalar_op, Log) and
            node.inputs[0].owner and
            isinstance(node.inputs[0].owner.op, GpuDnnSoftmax) and
            # Only fuse when the softmax output feeds the log alone;
            # otherwise the plain softmax result is still needed elsewhere.
            len(node.inputs[0].clients) == 1):
        # raises=False: return -1 instead of raising when cuDNN is
        # missing, so the comparison below stays safe during optimization.
        if version(raises=False) < 3000:
            # No log-softmax before cudnn v3
            raise_no_cudnn("Need CuDNN v3 for LogSoftmax")
        softmax_node = node.inputs[0].owner
        new_softmax = GpuDnnSoftmax('log', softmax_node.op.mode)
        return [new_softmax(softmax_node.inputs[0])]
...@@ -1502,14 +1508,14 @@ def local_log_softmax_dnn(node): ...@@ -1502,14 +1508,14 @@ def local_log_softmax_dnn(node):
@register_opt('cudnn') @register_opt('cudnn')
@op_lifter([LogSoftmax]) @op_lifter([LogSoftmax])
def local_logsoftmax_to_dnn(node, ctx_name): def local_logsoftmax_to_dnn(node, ctx_name):
if not dnn_available(ctx_name) or version() < 3000:
# No log-softmax before cudnn v3
raise_no_cudnn("Need CuDNN v3 for LogSoftmax")
# Transform the input in the format expected by GpuDnnSoftmax # Transform the input in the format expected by GpuDnnSoftmax
inp = node.inputs[0] inp = node.inputs[0]
if inp.ndim != 2: if inp.ndim != 2:
return return
if not dnn_available(ctx_name) or version(raises=False) < 3000:
# No log-softmax before cudnn v3
raise_no_cudnn("Need CuDNN v3 for LogSoftmax")
inp = inp.dimshuffle(0, 1, 'x', 'x') inp = inp.dimshuffle(0, 1, 'x', 'x')
inp.tag.context_name = ctx_name inp.tag.context_name = ctx_name
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论