提交 5cecb66b authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix list_contexts() so that it makes sense.

Also fix the logic in dnn_available() so that we always have an error message to present. Finally add the context name that failed in the NoCuDNNRaise error message.
上级 e20f07e6
...@@ -105,6 +105,7 @@ dnn_present.msg = None ...@@ -105,6 +105,7 @@ dnn_present.msg = None
def dnn_available(context_name): def dnn_available(context_name):
if not dnn_present(): if not dnn_present():
dnn_available.msg = dnn_present.msg
return False return False
ctx = get_context(context_name) ctx = get_context(context_name)
...@@ -121,6 +122,8 @@ def dnn_available(context_name): ...@@ -121,6 +122,8 @@ def dnn_available(context_name):
return True return True
dnn_available.msg = None
class DnnBase(COp): class DnnBase(COp):
""" """
...@@ -1424,7 +1427,7 @@ def local_log_softmax_dnn(node): ...@@ -1424,7 +1427,7 @@ def local_log_softmax_dnn(node):
class NoCuDNNRaise(Optimizer): class NoCuDNNRaise(Optimizer):
def apply(self, fgraph): def apply(self, fgraph):
""" """
Raise a RuntimeError if cudnn can't be used. Raise a error if cudnn can't be used.
""" """
for c in list_contexts(): for c in list_contexts():
...@@ -1432,8 +1435,8 @@ class NoCuDNNRaise(Optimizer): ...@@ -1432,8 +1435,8 @@ class NoCuDNNRaise(Optimizer):
# Make an assert error as we want Theano to fail, not # Make an assert error as we want Theano to fail, not
# just skip this optimization. # just skip this optimization.
raise AssertionError( raise AssertionError(
"cuDNN optimization was enabled, but Theano was not able" "cuDNN optimization was enabled, but Theano was not able "
" to use it. We got this error: \n" + "to use it for context " + c + ". We got this error: \n" +
dnn_available.msg) dnn_available.msg)
gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), 0, 'cudnn') gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), 0, 'cudnn')
......
...@@ -61,7 +61,10 @@ def get_context(name): ...@@ -61,7 +61,10 @@ def get_context(name):
def list_contexts(): def list_contexts():
return _context_reg.values() """
Return an iterable of all the registered context names.
"""
return _context_reg.keys()
# Private method # Private method
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论