提交 31b44f44 authored 作者: Frederic Bastien's avatar Frederic Bastien

added theano_cuda_ndarray.use() that try to use the gpu for theano shared variable.

If the THEANO_GPU env var is set to a number, we will use this card number.
上级 aa8d0246
......@@ -7,14 +7,28 @@ from .var import (CudaNdarrayVariable,
import basic_ops
import opt
import cuda_ndarray
import theano.compile.sandbox
import logging, os
def use():
handle_shared_float32(True)
def handle_shared_float32(tf):
"""Set the CudaNdarrayType as the default handler for shared float32 arrays
Use use(tf) instead as this is a bad name.
"""
if tf:
try:
v=os.getenv("THEANO_GPU",0)
cuda_ndarray.gpu_init(int(v))
theano.compile.sandbox.shared_constructor(shared_constructor)
except RuntimeError, e:
logging.getLogger('theano_cuda_ndarray').warning("WARNING: Won't use the GPU as the initialisation failed."+str(e))
else:
raise NotImplementedError('removing our handler')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论