提交 cde7f398 — 作者: Arnaud Bergeron

Make a method to list all registered context and use it in NoCudnnRaise.

上级 98da3fcc
...@@ -16,7 +16,7 @@ from theano.tensor.signal.downsample import ( ...@@ -16,7 +16,7 @@ from theano.tensor.signal.downsample import (
DownsampleFactorMax, MaxPoolGrad, AveragePoolGrad) DownsampleFactorMax, MaxPoolGrad, AveragePoolGrad)
from . import pygpu from . import pygpu
from .type import get_context, gpu_context_type from .type import get_context, gpu_context_type, list_contexts
from .basic_ops import (as_gpuarray_variable, infer_context_name, from .basic_ops import (as_gpuarray_variable, infer_context_name,
gpu_contiguous, HostFromGpu, gpu_contiguous, HostFromGpu,
GpuAllocEmpty, empty_like) GpuAllocEmpty, empty_like)
...@@ -82,26 +82,31 @@ def _dnn_check_version(): ...@@ -82,26 +82,31 @@ def _dnn_check_version():
return True, None return True, None
def dnn_available(context_name): def dnn_present():
if dnn_available.avail is False: if dnn_present.avail is not None:
return False return dnn_present.avail
if pygpu is None: if pygpu is None:
dnn_available.msg = "PyGPU not available" dnn_present.msg = "PyGPU not available"
dnn_available.avail = False dnn_present.avail = False
return False return False
# If we haven't checked yet, check if we can compile. dnn_present.avail, dnn_present.msg = _dnn_check_compile()
if dnn_available.avail is None: if dnn_present.avail:
dnn_available.avail, dnn_available.msg = _dnn_check_compile() dnn_present.avail, dnn_present.msg = _dnn_check_version()
if dnn_available.avail: if not dnn_present.avail:
dnn_available.avail, dnn_available.msg = _dnn_check_version() raise RuntimeError(dnn_present.msg)
if not dnn_available.avail:
raise RuntimeError(dnn_available.msg) return dnn_present.avail
if not dnn_available.avail:
dnn_present.avail = None
dnn_present.msg = None
def dnn_available(context_name):
if not dnn_present():
return False return False
# Don't cache these checks since they depend on the context
ctx = get_context(context_name) ctx = get_context(context_name)
if not ctx.kind == 'cuda': if not ctx.kind == 'cuda':
...@@ -116,9 +121,6 @@ def dnn_available(context_name): ...@@ -116,9 +121,6 @@ def dnn_available(context_name):
return True return True
dnn_available.avail = None
dnn_available.msg = None
class DnnBase(COp):
    """
...@@ -203,9 +205,7 @@ def version(): ...@@ -203,9 +205,7 @@ def version():
This also does a check that the header version matches the runtime version. This also does a check that the header version matches the runtime version.
""" """
if dnn_available.avail is None: if not dnn_present():
raise RuntimeError("called version() before dnn_available()")
if not dnn_available.avail:
raise Exception( raise Exception(
"We can't determine the cudnn version as it is not available", "We can't determine the cudnn version as it is not available",
dnn_available.msg) dnn_available.msg)
...@@ -1407,17 +1407,15 @@ def local_softmax_dnn(node): ...@@ -1407,17 +1407,15 @@ def local_softmax_dnn(node):
@register_opt('cudnn') @register_opt('cudnn')
@local_optimizer([GpuElemwise]) @local_optimizer([GpuElemwise])
def local_log_softmax_dnn(node): def local_log_softmax_dnn(node):
if version() < 3000:
# No log-softmax before cudnn v3
return
# This looks for GpuDnnSoftmax so we know that we have cudnn. # This looks for GpuDnnSoftmax so we know that we have cudnn.
if (isinstance(node.op, GpuElemwise) and if (isinstance(node.op, GpuElemwise) and
isinstance(node.op.scalar_op, Log) and isinstance(node.op.scalar_op, Log) and
node.inputs[0].owner and node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, GpuDnnSoftmax) and isinstance(node.inputs[0].owner.op, GpuDnnSoftmax) and
len(node.inputs[0].clients) == 1): len(node.inputs[0].clients) == 1):
# Don't move this call to version outside the condition, it
# needs to be here.
if version() < 3000:
# No log-softmax before cudnn v3
return
softmax_node = node.inputs[0].owner softmax_node = node.inputs[0].owner
new_softmax = GpuDnnSoftmax('log', softmax_node.op.mode) new_softmax = GpuDnnSoftmax('log', softmax_node.op.mode)
return [new_softmax(softmax_node.inputs[0])] return [new_softmax(softmax_node.inputs[0])]
...@@ -1429,14 +1427,8 @@ class NoCuDNNRaise(Optimizer): ...@@ -1429,14 +1427,8 @@ class NoCuDNNRaise(Optimizer):
Raise a RuntimeError if cudnn can't be used. Raise a RuntimeError if cudnn can't be used.
""" """
try: for c in list_contexts():
dnn_available(None) if not dnn_available(c):
except ValueError:
# This is most likely due to get_context()
pass
# This means we will have a problem no matter what context.
if not dnn_available.avail:
            # Make an assert error as we want Theano to fail, not
            # just skip this optimization.
            raise AssertionError(
......
...@@ -60,6 +60,10 @@ def get_context(name): ...@@ -60,6 +60,10 @@ def get_context(name):
return _context_reg[name] return _context_reg[name]
def list_contexts():
    return _context_reg.values()
# Private method
def _name_for_ctx(ctx):
    for k, v in _context_reg:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论