提交 cde7f398 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make a method to list all registered context and use it in NoCudnnRaise.

上级 98da3fcc
......@@ -16,7 +16,7 @@ from theano.tensor.signal.downsample import (
DownsampleFactorMax, MaxPoolGrad, AveragePoolGrad)
from . import pygpu
from .type import get_context, gpu_context_type
from .type import get_context, gpu_context_type, list_contexts
from .basic_ops import (as_gpuarray_variable, infer_context_name,
gpu_contiguous, HostFromGpu,
GpuAllocEmpty, empty_like)
......@@ -82,26 +82,31 @@ def _dnn_check_version():
return True, None
def dnn_available(context_name):
if dnn_available.avail is False:
return False
def dnn_present():
if dnn_present.avail is not None:
return dnn_present.avail
if pygpu is None:
dnn_available.msg = "PyGPU not available"
dnn_available.avail = False
dnn_present.msg = "PyGPU not available"
dnn_present.avail = False
return False
dnn_present.avail, dnn_present.msg = _dnn_check_compile()
if dnn_present.avail:
dnn_present.avail, dnn_present.msg = _dnn_check_version()
if not dnn_present.avail:
raise RuntimeError(dnn_present.msg)
return dnn_present.avail
dnn_present.avail = None
dnn_present.msg = None
def dnn_available(context_name):
if not dnn_present():
return False
# If we haven't checked yet, check if we can compile.
if dnn_available.avail is None:
dnn_available.avail, dnn_available.msg = _dnn_check_compile()
if dnn_available.avail:
dnn_available.avail, dnn_available.msg = _dnn_check_version()
if not dnn_available.avail:
raise RuntimeError(dnn_available.msg)
if not dnn_available.avail:
return False
# Don't cache these checks since they depend on the context
ctx = get_context(context_name)
if not ctx.kind == 'cuda':
......@@ -116,9 +121,6 @@ def dnn_available(context_name):
return True
dnn_available.avail = None
dnn_available.msg = None
class DnnBase(COp):
"""
......@@ -203,9 +205,7 @@ def version():
This also does a check that the header version matches the runtime version.
"""
if dnn_available.avail is None:
raise RuntimeError("called version() before dnn_available()")
if not dnn_available.avail:
if not dnn_present():
raise Exception(
"We can't determine the cudnn version as it is not available",
dnn_available.msg)
......@@ -1407,17 +1407,15 @@ def local_softmax_dnn(node):
@register_opt('cudnn')
@local_optimizer([GpuElemwise])
def local_log_softmax_dnn(node):
if version() < 3000:
# No log-softmax before cudnn v3
return
# This looks for GpuDnnSoftmax so we know that we have cudnn.
if (isinstance(node.op, GpuElemwise) and
isinstance(node.op.scalar_op, Log) and
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, GpuDnnSoftmax) and
len(node.inputs[0].clients) == 1):
# Don't move this call to version outside the condition, it
# needs to be here.
if version() < 3000:
# No log-softmax before cudnn v3
return
softmax_node = node.inputs[0].owner
new_softmax = GpuDnnSoftmax('log', softmax_node.op.mode)
return [new_softmax(softmax_node.inputs[0])]
......@@ -1429,20 +1427,14 @@ class NoCuDNNRaise(Optimizer):
Raise a RuntimeError if cudnn can't be used.
"""
try:
dnn_available(None)
except ValueError:
# This is most likely due to get_context()
pass
# This means we will have a problem no matter what context.
if not dnn_available.avail:
# Make an assert error as we want Theano to fail, not
# just skip this optimization.
raise AssertionError(
"cuDNN optimization was enabled, but Theano was not able"
" to use it. We got this error: \n" +
dnn_available.msg)
for c in list_contexts():
if not dnn_available(c):
# Make an assert error as we want Theano to fail, not
# just skip this optimization.
raise AssertionError(
"cuDNN optimization was enabled, but Theano was not able"
" to use it. We got this error: \n" +
dnn_available.msg)
gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), 0, 'cudnn')
......
......@@ -60,6 +60,10 @@ def get_context(name):
return _context_reg[name]
def list_contexts():
return _context_reg.values()
# Private method
def _name_for_ctx(ctx):
for k, v in _context_reg:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论