提交 50c16b4a authored 作者: Frederic's avatar Frederic

Fix pycuda import to remove tests error.

上级 ef1b9f3a
...@@ -41,15 +41,21 @@ if (not hasattr(theano.sandbox, 'cuda') or ...@@ -41,15 +41,21 @@ if (not hasattr(theano.sandbox, 'cuda') or
import pycuda.autoinit import pycuda.autoinit
pycuda_available = True pycuda_available = True
else: else:
try:
import pycuda.driver import pycuda.driver
pycuda_available = True
except ImportError:
pass
if pycuda_available:
if hasattr(pycuda.driver.Context, "attach"): if hasattr(pycuda.driver.Context, "attach"):
pycuda.driver.Context.attach() pycuda.driver.Context.attach()
else: else:
# Now we always import this file when we call theano.sandbox.cuda.use # Now we always import this file when we call
# So this should not happen normally. # theano.sandbox.cuda.use. So this should not happen
# normally.
# TODO: make this an error. # TODO: make this an error.
warnings.warn("For some unknow reason, theano.misc.pycuda_init was not" warnings.warn("For some unknow reason, theano.misc.pycuda_init was"
" imported before Theano initialized the GPU and" " not imported before Theano initialized the GPU and"
" your PyCUDA version is 2011.2.2 or earlier." " your PyCUDA version is 2011.2.2 or earlier."
" To fix the problem, import theano.misc.pycuda_init" " To fix the problem, import theano.misc.pycuda_init"
" manually before using/initializing the GPU, use the" " manually before using/initializing the GPU, use the"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论