提交 5cecb66b authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix list_contexts() so that it makes sense.

Also fix the logic in dnn_available() so that we always have an error message to present. Finally add the context name that failed in the NoCuDNNRaise error message.
上级 e20f07e6
......@@ -105,6 +105,7 @@ dnn_present.msg = None
def dnn_available(context_name):
if not dnn_present():
dnn_available.msg = dnn_present.msg
return False
ctx = get_context(context_name)
......@@ -121,6 +122,8 @@ def dnn_available(context_name):
return True
dnn_available.msg = None
class DnnBase(COp):
"""
......@@ -1424,7 +1427,7 @@ def local_log_softmax_dnn(node):
class NoCuDNNRaise(Optimizer):
def apply(self, fgraph):
"""
Raise a RuntimeError if cudnn can't be used.
Raise a error if cudnn can't be used.
"""
for c in list_contexts():
......@@ -1432,8 +1435,8 @@ class NoCuDNNRaise(Optimizer):
# Make an assert error as we want Theano to fail, not
# just skip this optimization.
raise AssertionError(
"cuDNN optimization was enabled, but Theano was not able"
" to use it. We got this error: \n" +
"cuDNN optimization was enabled, but Theano was not able "
"to use it for context " + c + ". We got this error: \n" +
dnn_available.msg)
gpu_seqopt.register("NoCuDNNRaise", NoCuDNNRaise(), 0, 'cudnn')
......
......@@ -61,7 +61,10 @@ def get_context(name):
def list_contexts():
return _context_reg.values()
"""
Return an iterable of all the registered context names.
"""
return _context_reg.keys()
# Private method
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论