提交 31b44f44 authored 作者: Frederic Bastien's avatar Frederic Bastien

added theano_cuda_ndarray.use() that try to use the gpu for theano shared variable.

If the THEANO_GPU env var is set to a number, we will use this card number.
上级 aa8d0246
...@@ -7,14 +7,28 @@ from .var import (CudaNdarrayVariable, ...@@ -7,14 +7,28 @@ from .var import (CudaNdarrayVariable,
import basic_ops import basic_ops
import opt import opt
import cuda_ndarray
import theano.compile.sandbox import theano.compile.sandbox
import logging, os
def use():
handle_shared_float32(True)
def handle_shared_float32(tf): def handle_shared_float32(tf):
"""Set the CudaNdarrayType as the default handler for shared float32 arrays """Set the CudaNdarrayType as the default handler for shared float32 arrays
Use use(tf) instead as this is a bad name.
""" """
if tf: if tf:
try:
v=os.getenv("THEANO_GPU",0)
cuda_ndarray.gpu_init(int(v))
theano.compile.sandbox.shared_constructor(shared_constructor) theano.compile.sandbox.shared_constructor(shared_constructor)
except RuntimeError, e:
logging.getLogger('theano_cuda_ndarray').warning("WARNING: Won't use the GPU as the initialisation failed."+str(e))
else: else:
raise NotImplementedError('removing our handler') raise NotImplementedError('removing our handler')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论